Commit 49c8f142 by xlwang

Merge branch 'master' of github.com:xlwang233/pytorch-DCRNN

parents 5bd00a8b 0c6f4eb7
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926), ICLR 2018.
## Requirements
- scipy==1.2.1
- numpy==1.16.2
- pandas==0.24.2
- torch>=1.1.0
- tqdm
- pytables
## Data Preparation
The traffic data files for Los Angeles (METR-LA) and the Bay Area (PEMS-BAY), i.e., `metr-la.h5` and `pems-bay.h5`, are available at [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](https://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g), and should be
put into the `data/` folder.
The `*.h5` files store the data as a `pandas.DataFrame` in the `HDF5` file format. Here is an example:
| | sensor_0 | sensor_1 | sensor_2 | sensor_n |
|:-------------------:|:--------:|:--------:|:--------:|:--------:|
| 2018/01/01 00:00:00 | 60.0 | 65.0 | 70.0 | ... |
| 2018/01/01 00:05:00 | 61.0 | 64.0 | 65.0 | ... |
| 2018/01/01 00:10:00 | 63.0 | 65.0 | 60.0 | ... |
| ... | ... | ... | ... | ... |
Here is an article about [Using HDF5 with Python](https://medium.com/@jerilkuriakose/using-hdf5-with-python-6c5242d08773).
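To sanity-check a downloaded file, it can be read directly with pandas (a minimal sketch; each file stores a single DataFrame):
```python
import pandas as pd

# Rows are 5-minute timestamps, columns are sensor readings.
df = pd.read_hdf('data/metr-la.h5')
print(df.shape)   # (num_timesteps, num_sensors)
print(df.head())
```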
Run the following commands to generate the train/val/test datasets at `data/{METR-LA,PEMS-BAY}/{train,val,test}.npz`.
```bash
# Create data directories
mkdir -p data/{METR-LA,PEMS-BAY}
# METR-LA
python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5
# PEMS-BAY
python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5
```
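Assuming the standard DCRNN windowing (12 input steps and 12 output steps by default), each generated `.npz` stores paired input/target windows under the keys `x` and `y`; a quick way to inspect them:
```python
import numpy as np

# Shapes below assume the script's default 12-step horizon;
# verify against your own output.
data = np.load('data/METR-LA/train.npz')
print(data['x'].shape)  # (num_samples, 12, num_sensors, input_dim)
print(data['y'].shape)  # (num_samples, 12, num_sensors, output_dim)
```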
## Graph Construction
As the current implementation relies on pre-calculated road-network distances between sensors, it only
supports the sensor IDs in Los Angeles (see `data/sensor_graph/sensor_info_201206.csv`).
```bash
python -m scripts.gen_adj_mx --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1 \
    --output_pkl_filename=data/sensor_graph/adj_mx.pkl
```
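Per the DCRNN paper, the adjacency matrix is built with a thresholded Gaussian kernel: `W[i][j] = exp(-dist(i, j)^2 / sigma^2)`, where `sigma` is the standard deviation of the distances and entries below `normalized_k` are zeroed out. A minimal numpy sketch (illustrative, not the repo's exact code):
```python
import numpy as np

def gaussian_kernel_adjacency(dist, normalized_k=0.1):
    # dist: (N, N) road-network distances, np.inf where unreachable.
    sigma = dist[~np.isinf(dist)].std()      # kernel bandwidth
    adj = np.exp(-np.square(dist / sigma))   # Gaussian kernel weights
    adj[adj < normalized_k] = 0.0            # sparsify by thresholding
    return adj
```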
In addition, the locations of the sensors in Los Angeles, i.e., METR-LA, are available at [data/sensor_graph/graph_sensor_locations.csv](https://github.com/liyaguang/DCRNN/blob/master/data/sensor_graph/graph_sensor_locations.csv).
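The resulting pickle can then be loaded for training or analysis; assuming it stores the `(sensor_ids, sensor_id_to_ind, adj_mx)` triple as in the original DCRNN repo:
```python
import pickle

with open('data/sensor_graph/adj_mx.pkl', 'rb') as f:
    # encoding='latin1' keeps pickles written under Python 2 readable.
    sensor_ids, sensor_id_to_ind, adj_mx = pickle.load(f, encoding='latin1')
print(adj_mx.shape)  # (num_sensors, num_sensors)
```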
## Run the Pre-trained Models
```bash
# METR-LA
python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml
# PEMS-BAY
python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml
```
The generated DCRNN predictions are saved in `data/results/dcrnn_predictions`.
For data preparation, check the original repo: [liyaguang/DCRNN](https://github.com/liyaguang/DCRNN).
## Model Training
For now, training is only supported for the METR-LA dataset due to data availability.
```bash
# METR-LA
python train.py --config config.json
```
Each epoch takes about 5-6 minutes (~340 seconds) on a single RTX 2080 Ti for METR-LA.
There is a chance that the training loss will explode; the temporary workaround is to restart from the last saved model before the explosion, or to decrease the learning rate earlier in the learning rate schedule.
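For the second workaround, decreasing the learning rate earlier amounts to moving the decay milestones forward in the schedule. A sketch assuming a `MultiStepLR`-style schedule (the milestone values here are hypothetical; the repo's actual schedule is set in `config.json`):
```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(2, 2)  # stand-in for the DCRNN model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Decaying at epochs 10 and 20 instead of, say, 20 and 30 tames an
# exploding loss at the cost of slower early progress.
scheduler = MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)
```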
## Evaluate Baseline Methods
```bash
# METR-LA
python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5
```
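These datasets mark missing readings as zeros, so comparisons typically use masked metrics that ignore them. A minimal sketch of a DCRNN-style masked MAE (an illustrative helper, not the script's exact function):
```python
import numpy as np

def masked_mae(pred, true, null_val=0.0):
    # Entries equal to null_val (missing readings) are excluded
    # from the average absolute error.
    mask = (true != null_val).astype(np.float64)
    mask /= mask.mean()
    return np.mean(np.abs(pred - true) * mask)
```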
More details are being added ...
## Citation
If you find this repository (e.g., the code or the datasets) useful in your research, please cite the following paper:
```
@inproceedings{li2018dcrnn_traffic,
title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
booktitle={International Conference on Learning Representations (ICLR '18)},
year={2018}
}
```
## Log and Model Savings
Log information will be saved at `saved/log/.../info.log`.
The best validated model will be saved at `saved/model/.../model_best.pth`.
The best results I have obtained so far are shown in `test_results.log`.
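To restore the best checkpoint for evaluation (a sketch; `<run_id>` is a placeholder for the actual run directory, and the stored keys depend on the training setup):
```python
import torch

# Substitute the timestamped run directory under saved/model/.
checkpoint = torch.load('saved/model/<run_id>/model_best.pth', map_location='cpu')
print(checkpoint.keys())  # inspect what the checkpoint stores
```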