@@ -19,13 +19,27 @@ class UniformSampler(Sampler):
1919 """
2020
2121 def __init__ (
22- self , min_value : Union [int , float ], max_value : Union [int , float ], ** kwargs
22+ self ,
23+ min_value : Union [int , float ],
24+ max_value : Union [int , float ],
25+ seed : Optional [int ] = None ,
26+ ** kwargs
2327 ) -> None :
28+ """
29+ :param min_value: minimum value of the range to be sampled uniformly from
30+ :param max_value: maximum value of the range to be sampled uniformly from
31+ :param seed: Random seed used for making draws from the uniform sampler
32+ """
2433 self .min_value = min_value
2534 self .max_value = max_value
35+ # Draw from random state to allow for consistent reset parameter draw for a seed
36+ self .random_state = np .random .RandomState (seed )
2637
2738 def sample_parameter (self ) -> float :
28- return np .random .uniform (self .min_value , self .max_value )
39+ """
40+ Draws and returns a sample from the specified interval
41+ """
42+ return self .random_state .uniform (self .min_value , self .max_value )
2943
3044
3145class MultiRangeUniformSampler (Sampler ):
@@ -36,19 +50,33 @@ class MultiRangeUniformSampler(Sampler):
3650 it proceeds to pick a value uniformly in that range.
3751 """
3852
39- def __init__ (self , intervals : List [List [Union [int , float ]]], ** kwargs ) -> None :
53+ def __init__ (
54+ self ,
55+ intervals : List [List [Union [int , float ]]],
56+ seed : Optional [int ] = None ,
57+ ** kwargs
58+ ) -> None :
59+ """
60+ :param intervals: List of intervals to draw uniform samples from
61+ :param seed: Random seed used for making uniform draws from the specified intervals
62+ """
4063 self .intervals = intervals
4164 # Measure the length of the intervals
4265 interval_lengths = [abs (x [1 ] - x [0 ]) for x in self .intervals ]
4366 cum_interval_length = sum (interval_lengths )
4467 # Assign weights to an interval proportionate to the interval size
4568 self .interval_weights = [x / cum_interval_length for x in interval_lengths ]
69+ # Draw from random state to allow for consistent reset parameter draw for a seed
70+ self .random_state = np .random .RandomState (seed )
4671
4772 def sample_parameter (self ) -> float :
73+ """
74+ Selects an interval to pick and then draws a uniform sample from the picked interval
75+ """
4876 cur_min , cur_max = self .intervals [
49- np . random .choice (len (self .intervals ), p = self .interval_weights )
77+ self . random_state .choice (len (self .intervals ), p = self .interval_weights )
5078 ]
51- return np . random .uniform (cur_min , cur_max )
79+ return self . random_state .uniform (cur_min , cur_max )
5280
5381
5482class GaussianSampler (Sampler ):
@@ -58,13 +86,27 @@ class GaussianSampler(Sampler):
5886 """
5987
6088 def __init__ (
61- self , mean : Union [float , int ], st_dev : Union [float , int ], ** kwargs
89+ self ,
90+ mean : Union [float , int ],
91+ st_dev : Union [float , int ],
92+ seed : Optional [int ] = None ,
93+ ** kwargs
6294 ) -> None :
95+ """
96+ :param mean: Specifies the mean of the gaussian distribution to draw from
97+ :param st_dev: Specifies the standard devation of the gaussian distribution to draw from
98+ :param seed: Random seed used for making gaussian draws from the sample
99+ """
63100 self .mean = mean
64101 self .st_dev = st_dev
102+ # Draw from random state to allow for consistent reset parameter draw for a seed
103+ self .random_state = np .random .RandomState (seed )
65104
66105 def sample_parameter (self ) -> float :
67- return np .random .normal (self .mean , self .st_dev )
106+ """
107+ Returns a draw from the specified Gaussian distribution
108+ """
109+ return self .random_state .normal (self .mean , self .st_dev )
68110
69111
70112class SamplerFactory :
@@ -81,17 +123,31 @@ class SamplerFactory:
81123
82124 @staticmethod
83125 def register_sampler (name : str , sampler_cls : Type [Sampler ]) -> None :
126+ """
127+ Registers the sampe in the Sampler Factory to be used later
128+ :param name: String name to set as key for the sampler_cls in the factory
129+ :param sampler_cls: Sampler object to associate to the name in the factory
130+ """
84131 SamplerFactory .NAME_TO_CLASS [name ] = sampler_cls
85132
86133 @staticmethod
87- def init_sampler_class (name : str , params : Dict [str , Any ]):
134+ def init_sampler_class (
135+ name : str , params : Dict [str , Any ], seed : Optional [int ] = None
136+ ) -> Sampler :
137+ """
138+ Initializes the sampler class associated with the name with the params
139+ :param name: Name of the sampler in the factory to initialize
140+ :param params: Parameters associated to the sampler attached to the name
141+ :param seed: Random seed to be used to set deterministic random draws for the sampler
142+ """
88143 if name not in SamplerFactory .NAME_TO_CLASS :
89144 raise SamplerException (
90145 name + " sampler is not registered in the SamplerFactory."
91146 " Use the register_sample method to register the string"
92147 " associated to your sampler in the SamplerFactory."
93148 )
94149 sampler_cls = SamplerFactory .NAME_TO_CLASS [name ]
150+ params ["seed" ] = seed
95151 try :
96152 return sampler_cls (** params )
97153 except TypeError :
@@ -103,7 +159,13 @@ def init_sampler_class(name: str, params: Dict[str, Any]):
103159
104160
105161class SamplerManager :
106- def __init__ (self , reset_param_dict : Dict [str , Any ]) -> None :
162+ def __init__ (
163+ self , reset_param_dict : Dict [str , Any ], seed : Optional [int ] = None
164+ ) -> None :
165+ """
166+ :param reset_param_dict: Arguments needed for initializing the samplers
167+ :param seed: Random seed to be used for drawing samples from the samplers
168+ """
107169 self .reset_param_dict = reset_param_dict if reset_param_dict else {}
108170 assert isinstance (self .reset_param_dict , dict )
109171 self .samplers : Dict [str , Sampler ] = {}
@@ -116,7 +178,7 @@ def __init__(self, reset_param_dict: Dict[str, Any]) -> None:
116178 )
117179 sampler_name = cur_param_dict .pop ("sampler-type" )
118180 param_sampler = SamplerFactory .init_sampler_class (
119- sampler_name , cur_param_dict
181+ sampler_name , cur_param_dict , seed
120182 )
121183
122184 self .samplers [param_name ] = param_sampler
@@ -128,6 +190,10 @@ def is_empty(self) -> bool:
128190 return not bool (self .samplers )
129191
130192 def sample_all (self ) -> Dict [str , float ]:
193+ """
194+ Loop over all samplers and draw a sample from each one for generating
195+ next set of reset parameter values.
196+ """
131197 res = {}
132198 for param_name , param_sampler in list (self .samplers .items ()):
133199 res [param_name ] = param_sampler .sample_parameter ()
0 commit comments