Towards fast weak adversarial training to solve high dimensional parabolic partial differential equations using XNODE-WAN
Due to the curse of dimensionality, solving high-dimensional parabolic partial differential equations (PDEs) has been a challenging problem for decades. Recently, the weak adversarial network (WAN) proposed in [1] has offered a flexible and computationally efficient approach to this problem on arbitrary domains by leveraging the weak formulation of the PDE. WAN reformulates the PDE problem as a generative adversarial network, in which the weak solution (primal network) and the test function (adversarial network) are parameterized by multi-layer deep neural networks (DNNs).
In our work, we design a novel model, called XNODE, that provides a universal and effective representation of the solution of a parabolic PDE. Built on the neural ODE model, XNODE is able to incorporate prior information about the PDE into the primal network. The proposed hybrid method, XNODE-WAN, which integrates the XNODE model into the WAN framework, significantly improves the performance and efficiency of training. Numerical results show that our method can reduce the training time to a fraction of that of the WAN model.
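To make the role of the weak formulation concrete, the sketch below estimates the weak residual ⟨A[u], φ⟩ for a hypothetical toy problem, the 1D heat equation u_t − u_xx = 0 on (0, 1) × (0, 1), by midpoint quadrature. It is not code from this repository. After integrating by parts in space, only first derivatives of u are needed, and the residual vanishes for the exact solution but not for a wrong candidate. In WAN, roughly speaking, u and φ are replaced by the primal and adversarial networks, and a normalized version of this quantity is minimized over u and maximized over φ; the closed-form functions here merely stand in for those networks.

```python
import math

# Hypothetical toy problem (not from the repository): the 1D heat equation
#   u_t - u_xx = 0  on (0, 1) x (0, 1),
# with exact solution u(x, t) = exp(-pi^2 t) * sin(pi x).

def exact_derivatives(x, t):
    """Partial derivatives (u_t, u_x) of the exact solution."""
    e = math.exp(-math.pi ** 2 * t)
    u_t = -math.pi ** 2 * e * math.sin(math.pi * x)
    u_x = math.pi * e * math.cos(math.pi * x)
    return u_t, u_x

def wrong_derivatives(x, t):
    """(u_t, u_x) for u(x, t) = sin(pi x), which does NOT solve the PDE."""
    return 0.0, math.pi * math.cos(math.pi * x)

def phi(x, t):
    """Test function x(1 - x)t, vanishing at x = 0 and x = 1, and its x-derivative."""
    return x * (1 - x) * t, (1 - 2 * x) * t

def weak_residual(derivatives, n=200):
    """Midpoint-rule estimate of <A[u], phi> = integral of (u_t * phi + u_x * phi_x).

    Integrating -u_xx * phi by parts in x moves one derivative onto phi;
    the boundary term vanishes because phi(0, t) = phi(1, t) = 0.
    """
    h = 1.0 / n
    total = 0.0
    for i in range(n):
        x = (i + 0.5) * h
        for j in range(n):
            t = (j + 0.5) * h
            u_t, u_x = derivatives(x, t)
            p, p_x = phi(x, t)
            total += (u_t * p + u_x * p_x) * h * h
    return total
```

Here weak_residual(exact_derivatives) is close to zero, while weak_residual(wrong_derivatives) is close to 2/π ≈ 0.64, so this test function detects that the second candidate is not a solution.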
More specifically, our XNODE-WAN algorithm aims to solve a boundary value problem (BVP) for a parabolic PDE on either a time-independent or a time-varying d-dimensional domain Ω, where Ω(t) denotes the spatial domain of Ω when time is restricted to t.
This repository is the official implementation of the paper entitled "Towards fast weak adversarial training to solve high dimensional parabolic partial differential equations using XNODE-WAN" (arXiv).
The requirements for running the code can be found in requirements.txt.
To solve a PDE, one specifies all the known functions of the problem in main.py, which runs the algorithm. A worked example can be found in example.ipynb, which implements the test problem from the paper.
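As a hedged illustration of what the known functions of a problem can look like, the sketch below defines a hypothetical d-dimensional heat-type problem u_t − Δu = f with boundary data g and initial data h, built from the manufactured solution u(t, x) = t|x|². It is not the test problem from the paper, and the exact signatures main.py expects may differ; here every function takes a point z = [t, x_1, ..., x_d] with time first. The finite-difference helper pde_residual is our own quick check that f really matches the chosen solution.

```python
# Hypothetical example problem (NOT the test problem from the paper):
#   u_t - Laplacian(u) = f   in the domain,
#   u = g                    on the boundary,
#   u = h                    at t = T0 = 0,
# with manufactured exact solution u(t, x) = t * |x|^2.
# Each function takes z = [t, x_1, ..., x_d], time first.

def u_exact(z):
    t, x = z[0], z[1:]
    return t * sum(xi * xi for xi in x)

def f(z):
    """Source term: u_t - Laplacian(u) = |x|^2 - 2 d t."""
    t, x = z[0], z[1:]
    return sum(xi * xi for xi in x) - 2 * len(x) * t

def g(z):
    """Boundary data: the exact solution evaluated on the boundary."""
    return u_exact(z)

def h(z):
    """Initial data: the exact solution at t = 0."""
    return u_exact([0.0] + list(z[1:]))

def pde_residual(u, z, eps=1e-4):
    """Central-difference estimate of u_t - Laplacian(u) - f at the point z."""
    def shifted(k, delta):
        w = list(z)
        w[k] += delta
        return u(w)

    u_t = (shifted(0, eps) - shifted(0, -eps)) / (2 * eps)
    laplacian = sum(
        (shifted(k, eps) - 2 * u(z) + shifted(k, -eps)) / eps ** 2
        for k in range(1, len(z))
    )
    return u_t - laplacian - f(z)
```

For the exact solution the residual is zero up to rounding error, whereas adding x_1² to u shifts it by −2, since that term contributes 2 to the Laplacian.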
In the config dictionary, one can specify the hyperparameters used by the algorithm to solve the problem.
In the setup dictionary, the problem-specific information is included:
- dim: the dimension of our problem's domain (excluding time)
- N_t: the number of time points sampled for each path
- N_r: the number of paths in the interior of the domain
- N_b: the number of paths on the boundary of the domain
- T0: the minimum time at which our domain exists
- T: the maximum time at which our domain exists
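The setup dictionary could look as follows. This is a minimal sketch: the values are hypothetical, only the key names come from the list above, and validate_setup and expected_shapes are our own illustrative helpers rather than functions of the repository. The shapes assume, as we read the descriptions in this README, that interior and boundary batches hold N_r and N_b paths respectively, with N_t time points per path and dim + 1 channels (time plus the spatial coordinates).

```python
# Hypothetical values for illustration; only the key names come from this README.
setup = {
    "dim": 5,     # spatial dimension of the domain (time excluded)
    "N_t": 10,    # time points sampled for each path
    "N_r": 1000,  # paths in the interior of the domain
    "N_b": 200,   # paths on the boundary of the domain
    "T0": 0.0,    # minimum time at which the domain exists
    "T": 1.0,     # maximum time at which the domain exists
}

def validate_setup(s):
    """Sanity checks one might run before training (not part of the repository)."""
    required = ("dim", "N_t", "N_r", "N_b", "T0", "T")
    missing = [k for k in required if k not in s]
    if missing:
        raise KeyError(f"setup is missing keys: {missing}")
    for k in ("dim", "N_t", "N_r", "N_b"):
        if not (isinstance(s[k], int) and s[k] > 0):
            raise ValueError(f"{k} must be a positive integer, got {s[k]!r}")
    if not s["T0"] < s["T"]:
        raise ValueError("T0 must be strictly smaller than T")

def expected_shapes(s):
    """Assumed [N, L, C] shapes of interior and boundary batches, C = dim + 1."""
    c = s["dim"] + 1  # time in channel 0, then the dim spatial coordinates
    return {
        "interior": (s["N_r"], s["N_t"], c),
        "boundary": (s["N_b"], s["N_t"], c),
    }
```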
Note that the data structure is [N, L, C], where N is the number of different points, L is the number of time points at which they are evaluated (these have to be the same for all points), and C is the axis of the coordinates, with time stored first ([:, :, 0] indexes the times of all points).
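The [N, L, C] layout can be illustrated without any dependencies by using nested Python lists in place of a torch tensor. In this hypothetical sketch (make_batch, shape and all_times are illustrative names, not repository functions), each of N spatial points is simply paired with the same L times, and all_times plays the role of [:, :, 0].

```python
# Nested lists stand in for a torch tensor of shape [N, L, C].

def make_batch(spatial_points, times):
    """Build an [N, L, C] nested list from N spatial points and L shared times.

    Each entry is [t, x_1, ..., x_d], so time sits in channel 0 and C = d + 1.
    Pairing every point with the same times is a simplification for illustration.
    """
    return [[[t] + list(x) for t in times] for x in spatial_points]

def shape(batch):
    """(N, L, C) of a non-empty nested batch."""
    return (len(batch), len(batch[0]), len(batch[0][0]))

def all_times(batch):
    """Equivalent of batch[:, :, 0] for a tensor: the time channel of every entry."""
    return [[entry[0] for entry in path] for path in batch]
```

For example, three 2D points evaluated at four common times give a batch of shape (3, 4, 3).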
Because of certain complications in evaluating points, the best method is to feed the network the points to be evaluated one at a time. To ensure accurate computations, a single point must be passed in the form torch.tensor([[x0, x]]), where x is the point you want to evaluate and x0 has the same spatial coordinates as x, but its time coordinate is the earliest time at which this spatial point lies in the domain (this is T0 for all points of the domain at time T0).
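The helper below assembles the nested list that would be wrapped in torch.tensor for a single point. It is a sketch with hypothetical names: single_point_input and inputs_for_points are not part of the repository, and t_start stands for the time coordinate of x0 described above, which the caller must supply for the domain at hand.

```python
def single_point_input(x_spatial, t, t_start):
    """Return [[x0, x]] for one point, i.e. shape [N=1, L=2, C=d+1].

    x  = [t, *x_spatial]        the point to evaluate
    x0 = [t_start, *x_spatial]  same spatial coordinates, time set to t_start
    """
    x = [t] + list(x_spatial)
    x0 = [t_start] + list(x_spatial)
    return [[x0, x]]

def inputs_for_points(points, t, t_start_of):
    """One [[x0, x]] input per point, to be fed to the network individually.

    t_start_of is a caller-supplied (hypothetical) function mapping a spatial
    point to the time coordinate of its x0.
    """
    return [single_point_input(p, t, t_start_of(p)) for p in points]
```

For example, torch.tensor(single_point_input([0.2, -0.1], t=0.7, t_start=0.0)) would have shape (1, 2, 3): one point, two time entries and three coordinates.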
The algorithm supports a wide variety of domains, including time-varying ones, and these can be specified in dataset.py, which already contains some examples. It is important to conform to the structure used in this file to guarantee that the algorithm works.
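To illustrate what specifying a time-varying domain involves, here is a minimal sketch using a hypothetical expanding ball Ω(t) = {x : |x| < 1 + t}. It does not follow the interface of dataset.py (radius, sample_interior and sample_boundary are illustrative names), so any real domain should still be written in the structure used in that file. Interior points are drawn uniformly from the ball by combining a uniformly random direction with a radius distributed as r(t)·U^(1/d), and boundary points are drawn uniformly from the sphere of radius r(t).

```python
import math
import random

def radius(t):
    """Radius of the hypothetical expanding ball Omega(t) = {x : |x| < 1 + t}."""
    return 1.0 + t

def _unit_direction(dim, rng):
    # A normalised standard Gaussian vector is uniformly distributed on the unit sphere.
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]

def sample_interior(n, dim, t, rng):
    """n points distributed uniformly in Omega(t)."""
    points = []
    for _ in range(n):
        # Radius r(t) * U^(1/dim) gives a uniform density over the ball's volume.
        r = radius(t) * rng.random() ** (1.0 / dim)
        points.append([r * c for c in _unit_direction(dim, rng)])
    return points

def sample_boundary(n, dim, t, rng):
    """n points distributed uniformly on the boundary sphere of Omega(t)."""
    return [[radius(t) * c for c in _unit_direction(dim, rng)] for _ in range(n)]
```

Passing an explicit random.Random instance keeps the sampling reproducible, e.g. sample_interior(500, 4, 0.5, random.Random(0)).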