1+ import os .path as osp
2+ import torch
3+ import numpy as np
4+ from torch_geometric .data import InMemoryDataset , Data
5+ from torch_geometric .io import read_txt_array
6+ import torch .nn .functional as F
7+
8+ import scipy
9+ import pickle as pkl
10+ import csv
11+ import json
12+
13+ import warnings
14+ warnings .filterwarnings ('ignore' , category = DeprecationWarning )
15+
16+
class CitationDataset(InMemoryDataset):
    """Single-graph node-classification dataset built from three text files.

    ``process`` reads ``{name}_edgelist.txt``, ``{name}_docs.txt`` and
    ``{name}_labels.txt`` from ``raw_dir``, attaches a random 80/10/10
    train/val/test node split, and stores the collated graph in ``data.pt``.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        name: Prefix shared by the three raw files.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; when set it is
            applied to the graph once, inside ``process``.
        pre_filter: Forwarded to ``InMemoryDataset``; ``process`` itself does
            not consult it.
    """

    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        # ``name`` must be set before the base constructor runs, because the
        # file-name properties below read it.
        self.name = name
        self.root = root
        super().__init__(root, transform, pre_transform, pre_filter)

        # NOTE(review): PyTorch releases where ``torch.load`` defaults to
        # ``weights_only=True`` may refuse to unpickle the stored objects;
        # confirm against the pinned torch / torch_geometric versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw files, matching exactly what ``process`` opens."""
        return [
            f'{self.name}_docs.txt',
            f'{self.name}_edgelist.txt',
            f'{self.name}_labels.txt',
        ]

    @property
    def processed_file_names(self):
        """Name of the single processed file."""
        return ['data.pt']

    def download(self):
        """No-op: the raw files are expected to already be in ``raw_dir``."""
        pass

    @staticmethod
    def _random_split_masks(num_nodes, train_frac=0.8, val_frac=0.1):
        """Return random, disjoint ``(train, val, test)`` boolean node masks.

        The permutation is drawn from NumPy's global RNG, so the split is only
        reproducible if the caller seeds ``np.random`` beforehand. The test
        mask receives every node not assigned to train or val.
        """
        order = torch.from_numpy(np.random.permutation(num_nodes)).to(torch.long)
        n_train = int(num_nodes * train_frac)
        n_val = int(num_nodes * val_frac)
        parts = (
            order[:n_train],
            order[n_train:n_train + n_val],
            order[n_train + n_val:],
        )
        masks = []
        for part in parts:
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[part] = True
            masks.append(mask)
        return tuple(masks)

    def process(self):
        """Parse the raw text files and write the collated graph to disk.

        Raises:
            FileNotFoundError: If one of the raw files is missing.
            ValueError: If a feature or label entry is not numeric.
        """
        edge_path = osp.join(self.raw_dir, f'{self.name}_edgelist.txt')
        # The file holds one comma-separated pair per line; transpose so the
        # two endpoints become the two rows of ``edge_index``.
        edge_index = read_txt_array(
            edge_path, sep=',', dtype=torch.long).t().contiguous()

        docs_path = osp.join(self.raw_dir, f'{self.name}_docs.txt')
        with open(docs_path, encoding='utf-8') as f:
            # One comma-separated feature row per line; blank lines (such as a
            # trailing newline at end of file) are skipped.
            feature_rows = [line.split(',') for line in f if line.strip()]
        x = torch.from_numpy(np.array(feature_rows, dtype=float)).to(torch.float)

        label_path = osp.join(self.raw_dir, f'{self.name}_labels.txt')
        with open(label_path, encoding='utf-8') as f:
            # One integer label per line.
            label_rows = [line.strip() for line in f if line.strip()]
        y = torch.from_numpy(np.array(label_rows, dtype=int)).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)
        (data.train_mask,
         data.val_mask,
         data.test_mask) = self._random_split_masks(y.shape[0])

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])
93+
94+
class EllipticDataset(InMemoryDataset):
    """Single-graph node-classification dataset loaded from one pickle file.

    ``process`` unpickles ``{name}.pkl`` from ``raw_dir`` into an
    ``(A, label, features)`` triple, attaches a random 80/10/10
    train/val/test node split, and stores the collated graph in ``data.pt``.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        name: Stem of the raw pickle file.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; when set it is
            applied to the graph once, inside ``process``.
        pre_filter: Forwarded to ``InMemoryDataset``; ``process`` itself does
            not consult it.
    """

    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        # ``name`` must be set before the base constructor runs, because
        # ``raw_file_names`` reads it.
        self.name = name
        self.root = root
        super().__init__(root, transform, pre_transform, pre_filter)

        # NOTE(review): PyTorch releases where ``torch.load`` defaults to
        # ``weights_only=True`` may refuse to unpickle the stored objects;
        # confirm against the pinned torch / torch_geometric versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Name of the raw pickle, matching exactly what ``process`` opens."""
        return [f'{self.name}.pkl']

    @property
    def processed_file_names(self):
        """Name of the single processed file."""
        return ['data.pt']

    def download(self):
        """No-op: the raw pickle is expected to already be in ``raw_dir``."""
        pass

    @staticmethod
    def _random_split_masks(num_nodes, train_frac=0.8, val_frac=0.1):
        """Return random, disjoint ``(train, val, test)`` boolean node masks.

        The permutation is drawn from NumPy's global RNG, so the split is only
        reproducible if the caller seeds ``np.random`` beforehand. The test
        mask receives every node not assigned to train or val.
        """
        order = torch.from_numpy(np.random.permutation(num_nodes)).to(torch.long)
        n_train = int(num_nodes * train_frac)
        n_val = int(num_nodes * val_frac)
        parts = (
            order[:n_train],
            order[n_train:n_train + n_val],
            order[n_train + n_val:],
        )
        masks = []
        for part in parts:
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[part] = True
            masks.append(mask)
        return tuple(masks)

    def process(self):
        """Unpickle the raw triple and write the collated graph to disk.

        Raises:
            FileNotFoundError: If ``{name}.pkl`` is missing from ``raw_dir``.
        """
        path = osp.join(self.raw_dir, f'{self.name}.pkl')
        # SECURITY: unpickling can execute arbitrary code; only load pickle
        # files that come from a trusted source.
        with open(path, 'rb') as f:
            A, label, features = pkl.load(f)

        # Shift every label up by one. Presumably this maps a -1 "unknown"
        # class to 0 so all labels are non-negative; verify against whatever
        # produced the pickle.
        label = label + 1

        # ``A.nonzero()`` yields the (row, col) index arrays of the non-zero
        # adjacency entries; stacked they form the 2 x E ``edge_index``.
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        x = torch.from_numpy(np.array(features)).to(torch.float)
        y = torch.tensor(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)
        (data.train_mask,
         data.val_mask,
         data.test_mask) = self._random_split_masks(y.shape[0])

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])
158+
159+
class TwitchDataset(InMemoryDataset):
    """Single-graph node-classification dataset built from MUSAE Twitch files.

    ``process`` reads ``musae_{name}_target.csv``, ``musae_{name}_edges.csv``
    and ``musae_{name}_features.json`` from ``raw_dir``, attaches a random
    80/10/10 train/val/test node split, and stores the collated graph in
    ``data.pt``.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        name: Language code; ``load_twitch`` accepts ``'DE'``, ``'EN'`` and
            ``'FR'``.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; when set it is
            applied to the graph once, inside ``process``.
        pre_filter: Forwarded to ``InMemoryDataset``; ``process`` itself does
            not consult it.
    """

    # Width of the binary feature matrix built by ``load_twitch``.
    NUM_FEATURES = 3170

    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        # ``name`` must be set before the base constructor runs, because
        # ``raw_file_names`` reads it.
        self.name = name
        self.root = root
        super().__init__(root, transform, pre_transform, pre_filter)

        # NOTE(review): PyTorch releases where ``torch.load`` defaults to
        # ``weights_only=True`` may refuse to unpickle the stored objects;
        # confirm against the pinned torch / torch_geometric versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw files, matching exactly what ``load_twitch`` opens."""
        return [
            f'musae_{self.name}_edges.csv',
            f'musae_{self.name}_features.json',
            f'musae_{self.name}_target.csv',
        ]

    @property
    def processed_file_names(self):
        """Name of the single processed file."""
        return ['data.pt']

    def download(self):
        """No-op: the raw files are expected to already be in ``raw_dir``."""
        pass

    def load_twitch(self, lang):
        """Load adjacency, labels and features for one Twitch language graph.

        Args:
            lang: Language code, one of ``'DE'``, ``'EN'``, ``'FR'``.

        Returns:
            Tuple ``(A, label, features)`` where ``A`` is an ``n x n`` SciPy
            CSR matrix with an entry for every row of the edges file,
            ``label`` is a length-``n`` integer array ordered by node id, and
            ``features`` is an ``n x NUM_FEATURES`` 0/1 array. ``n`` is the
            number of distinct node ids in the target file.

        Raises:
            AssertionError: If ``lang`` is not a supported code.
            KeyError: If the node ids in the target file are not exactly
                ``0 .. n-1``.
        """
        # ``import scipy`` alone does not guarantee the ``sparse`` submodule is
        # loaded on every SciPy version, so import it explicitly here.
        from scipy.sparse import csr_matrix

        assert lang in ('DE', 'EN', 'FR'), 'Invalid dataset'
        raw_dir = self.raw_dir

        labels = []
        node_ids = []
        seen_ids = set()
        target_path = osp.join(raw_dir, f'musae_{lang}_target.csv')
        with open(target_path, 'r', newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for row in reader:
                node_id = int(row[5])
                # handle FR case of non-unique rows: keep the first occurrence
                if node_id in seen_ids:
                    continue
                seen_ids.add(node_id)
                node_ids.append(node_id)
                # Column 2 holds the literal text "True"/"False".
                labels.append(int(row[2] == 'True'))

        src = []
        targ = []
        edges_path = osp.join(raw_dir, f'musae_{lang}_edges.csv')
        with open(edges_path, 'r', newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for row in reader:
                src.append(int(row[0]))
                targ.append(int(row[1]))

        features_path = osp.join(raw_dir, f'musae_{lang}_features.json')
        with open(features_path, 'r', encoding='utf-8') as f:
            feats_by_node = json.load(f)

        label = np.array(labels, dtype=np.int64)
        n = label.shape[0]

        # Rows of the target file are not necessarily sorted by node id.
        # ``order[i]`` is the row at which node id ``i`` appeared, so indexing
        # ``label`` with it re-orders the labels by node id.
        row_of_id = {node_id: row for row, node_id in enumerate(node_ids)}
        order = np.fromiter((row_of_id[i] for i in range(n)),
                            dtype=np.int64, count=n)

        A = csr_matrix(
            (np.ones(len(src)),
             (np.array(src, dtype=np.int64), np.array(targ, dtype=np.int64))),
            shape=(n, n))

        features = np.zeros((n, self.NUM_FEATURES))
        for node, feats in feats_by_node.items():
            node_idx = int(node)
            if node_idx >= n:
                # Feature entries for ids outside the labelled range are dropped.
                continue
            # ``feats`` lists the active column indices for this node.
            features[node_idx, np.array(feats, dtype=int)] = 1

        label = label[order]

        return A, label, features

    @staticmethod
    def _random_split_masks(num_nodes, train_frac=0.8, val_frac=0.1):
        """Return random, disjoint ``(train, val, test)`` boolean node masks.

        The permutation is drawn from NumPy's global RNG, so the split is only
        reproducible if the caller seeds ``np.random`` beforehand. The test
        mask receives every node not assigned to train or val.
        """
        order = torch.from_numpy(np.random.permutation(num_nodes)).to(torch.long)
        n_train = int(num_nodes * train_frac)
        n_val = int(num_nodes * val_frac)
        parts = (
            order[:n_train],
            order[n_train:n_train + n_val],
            order[n_train + n_val:],
        )
        masks = []
        for part in parts:
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            mask[part] = True
            masks.append(mask)
        return tuple(masks)

    def process(self):
        """Build the graph for ``self.name`` and write it to disk."""
        A, label, features = self.load_twitch(self.name)

        # ``A.nonzero()`` yields the (row, col) index arrays of the stored
        # edges; stacked they form the 2 x E ``edge_index``.
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        x = torch.from_numpy(features).to(torch.float)
        y = torch.from_numpy(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)
        (data.train_mask,
         data.val_mask,
         data.test_mask) = self._random_split_masks(y.shape[0])

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])
0 commit comments