diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py index 78fea41662e49..2de94fdcbb193 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform import unittest import numpy as np @@ -171,28 +170,24 @@ def check_prim(self, net, use_prim): self.assertTrue('layer_norm' not in fwd_ops) def test_cinn_prim(self): - plat = platform.system() - if plat == "Linux": - for dtype in self.dtypes: - if paddle.device.get_device() == "cpu": - print("need pass this case") - continue - x_n, w_n, b_n = generate_data(dtype) - self.x = paddle.to_tensor(x_n) - self.w = paddle.to_tensor(w_n) - self.b = paddle.to_tensor(b_n) - self.x.stop_gradient = False - dy_res = self.train(use_prim=False) - cinn_res = self.train(use_prim=True) - - np.testing.assert_allclose( - cinn_res, - dy_res, - rtol=TOLERANCE[dtype]['rtol'], - atol=TOLERANCE[dtype]['atol'], - ) - else: - pass + for dtype in self.dtypes: + if paddle.device.get_device() == "cpu": + print("need pass this case") + continue + x_n, w_n, b_n = generate_data(dtype) + self.x = paddle.to_tensor(x_n) + self.w = paddle.to_tensor(w_n) + self.b = paddle.to_tensor(b_n) + self.x.stop_gradient = False + dy_res = self.train(use_prim=False) + cinn_res = self.train(use_prim=True) + + np.testing.assert_allclose( + cinn_res, + dy_res, + rtol=TOLERANCE[dtype]['rtol'], + atol=TOLERANCE[dtype]['atol'], + ) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py index ff433f439e056..ae2de19c8721d 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform import unittest import numpy as np @@ -185,31 +184,24 @@ def check_prim(self, net, use_prim): self.assertTrue('reduce_mean' not in fwd_ops) def test_cinn_prim(self): - plat = platform.system() - if plat == "Linux": - for shape in self.shapes: - for dtype in self.dtypes: - # mean-kernel on cpu not support float16 - if ( - paddle.device.get_device() == "cpu" - and dtype == "float16" - ): - print("need pass this case") - continue - data = generate_data(shape, dtype) - data_t = paddle.to_tensor(data) - data_t.stop_gradient = False - dy_res = self.train(use_prim=False, data=data_t) - cinn_res = self.train(use_prim=True, data=data_t) - - np.testing.assert_allclose( - cinn_res, - dy_res, - rtol=TOLERANCE[dtype]['rtol'], - atol=TOLERANCE[dtype]['atol'], - ) - else: - pass + for shape in self.shapes: + for dtype in self.dtypes: + # mean-kernel on cpu not support float16 + if paddle.device.get_device() == "cpu" and dtype == "float16": + print("need pass this case") + continue + data = generate_data(shape, dtype) + data_t = paddle.to_tensor(data) + data_t.stop_gradient = False + dy_res = self.train(use_prim=False, data=data_t) + cinn_res = self.train(use_prim=True, data=data_t) + + np.testing.assert_allclose( + cinn_res, + dy_res, + rtol=TOLERANCE[dtype]['rtol'], + atol=TOLERANCE[dtype]['atol'], + ) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py index 407e11349c2de..fed94b4003020 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py @@ -14,7 +14,6 @@ import math import os -import platform import tempfile import time import unittest @@ -442,22 +441,18 @@ def test_resnet_composite_backward(self): ) def test_resnet_composite_forward_backward(self): - plat = platform.system() - if plat == "Linux": - core._set_prim_all_enabled(True) - static_loss = self.train(to_static=True) - core._set_prim_all_enabled(False) - dygraph_loss = self.train(to_static=True) - np.testing.assert_allclose( - static_loss, - dygraph_loss, - rtol=1e-02, - err_msg='static_loss: {} \n dygraph_loss: {}'.format( - static_loss, dygraph_loss - ), - ) - else: - pass + core._set_prim_all_enabled(True) + static_loss = self.train(to_static=True) + core._set_prim_all_enabled(False) + dygraph_loss = self.train(to_static=True) + np.testing.assert_allclose( + static_loss, + dygraph_loss, + rtol=1e-02, + err_msg='static_loss: {} \n dygraph_loss: {}'.format( + static_loss, dygraph_loss + ), + ) def test_in_static_mode_mkldnn(self): fluid.set_flags({'FLAGS_use_mkldnn': True}) diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py b/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py index 32e83c4b2abe7..3d5fec6ed2ed7 100644 --- a/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py +++ b/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform import time import unittest @@ -63,9 +62,7 @@ def train(to_static, enable_prim, enable_cinn): np.random.seed(SEED) paddle.seed(SEED) paddle.framework.random._manual_program_seed(SEED) - fluid.core._set_prim_all_enabled( - enable_prim and platform.system() == 'Linux' - ) + fluid.core._set_prim_all_enabled(enable_prim) train_reader = paddle.batch( reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags_case.py b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags_case.py index 309959747e064..b2e2ad05ea439 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags_case.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags_case.py @@ -13,7 +13,6 @@ # limitations under the License. import os -import platform import unittest import paddle @@ -95,78 +94,48 @@ def test_cinn_prim_all(self): self.reset_env_flag() os.environ["FLAGS_prim_all"] = "True" self.flag = "cinn_prim_all" - plat = platform.system() - if plat == "Linux": - _ = self.train(use_cinn=True) - else: - pass + _ = self.train(use_cinn=True) def test_prim_all(self): """prim forward + prim backward""" self.reset_env_flag() os.environ["FLAGS_prim_all"] = "True" self.flag = "prim_all" - plat = platform.system() - if plat == "Linux": - _ = self.train(use_cinn=False) - else: - pass + _ = self.train(use_cinn=False) def test_cinn_prim_forward(self): """cinn + prim forward""" - self.reset_env_flag() - os.environ["FLAGS_prim_forward"] = "True" self.flag = "cinn_prim_forward" - plat = platform.system() - if plat == "Linux": - _ = self.train(use_cinn=True) - else: - pass + _ = self.train(use_cinn=True) def test_prim_forward(self): """only prim forward""" self.reset_env_flag() os.environ["FLAGS_prim_forward"] = "True" self.flag = "prim_forward" - plat = platform.system() - if plat == "Linux": - _ = self.train(use_cinn=False) - else: - pass + _ = self.train(use_cinn=False) def test_cinn_prim_backward(self): """cinn + prim_backward""" self.reset_env_flag() os.environ["FLAGS_prim_backward"] = "True" self.flag = "cinn_prim_backward" - plat = platform.system() - if plat == "Linux": - _ = self.train(use_cinn=True) - else: - pass + _ = self.train(use_cinn=True) def test_prim_backward(self): """only prim backward""" self.reset_env_flag() os.environ["FLAGS_prim_backward"] = "True" self.flag = "prim_backward" - plat = platform.system() - if plat == "Linux": - _ = self.train(use_cinn=False) - else: - pass + _ = self.train(use_cinn=False) def test_cinn(self): """only cinn""" self.reset_env_flag() self.flag = "cinn" - plat = platform.system() - if plat == "Linux": - _ = self.train(use_cinn=True) - else: - pass + _ = self.train(use_cinn=True) if __name__ == '__main__':