User Tools

Site Tools


cs501r_f2018:lab9

Differences

This shows you the differences between two versions of the page.

Link to this comparison view

Both sides previous revision Previous revision
Next revision Both sides next revision
cs501r_f2018:lab9 [2018/11/12 21:22]
wingated
cs501r_f2018:lab9 [2018/11/15 23:15]
wingated
Line 73: Line 73:
     ....     ....
  
-class AdvantageDataset(Dataset):​ +class AdvantageDataset(Dataset): ​                                                                                                                    
-    .... +    ​def __init__(self,​ experience): ​                                                                                                                 
-     ​ +        super(AdvantageDataset,​ self).__init__() ​                                                                                                    
-class PolicyDataset(Dataset):​ +        self._exp = experience ​                                                                                                                      
-    ....+        self._num_runs = len(experience) ​                                                                                                            
 +        self._length = reduce(lambda acc, x: acc + len(x), experience, 0)                                                                            
 +                                                                                                                                                     
 +    def __getitem__(self,​ index): ​                                                                                                                   
 +        idx = 0                                                                                                                                      
 +        seen_data = 0                                                                                                                                
 +        current_exp = self._exp[0] ​                                                                                                                  
 +        while seen_data + len(current_exp) - 1 < index: ​                                                                                             
 +            seen_data += len(current_exp) ​                                                                                                           
 +            idx += 1                                                                                                                                 
 +            current_exp = self._exp[idx] ​                                                                                                            
 +        chosen_exp = current_exp[index - seen_data] ​                                                                                                 
 +        return chosen_exp[0],​ chosen_exp[4] ​                                                                                                         
 +                                                                                                                                                     
 +    def __len__(self): ​                                                                                                                              
 +        return self._length ​                                                                                                                         
 +                                                                                                                                                     
 +                                                                                                                                                    ​ 
 +class PolicyDataset(Dataset): ​                                                                                                                      ​ 
 +    ​def __init__(self,​ experience): ​                                                                                                                 
 +        super(PolicyDataset,​ self).__init__() ​                                                                                                       
 +        self._exp = experience ​                                                                                                                      
 +        self._num_runs = len(experience) ​                                                                                                            
 +        self._length = reduce(lambda acc, x: acc + len(x), experience, 0)                                                                            
 +                                                                                                                                                     
 +    def __getitem__(self,​ index): ​                                                                                                                   
 +        idx = 0                                                                                                                                      
 +        seen_data = 0                                                                                                                                
 +        current_exp = self._exp[0] ​                                                                                                                  
 +        while seen_data + len(current_exp) - 1 < index: ​                                                                                             
 +            seen_data += len(current_exp) ​                                                                                                           
 +            idx += 1                                                                                                                                 
 +            current_exp = self._exp[idx] ​                                                                                                            
 +        chosen_exp = current_exp[index - seen_data] ​                                                                                                 
 +        return chosen_exp ​                                                                                                                           
 +                                                                                                                                                     
 +    def __len__(self): ​                                                                                                                              
 +        return self._length ​                                                                                                                         
  
 def main(): def main():
cs501r_f2018/lab9.txt · Last modified: 2021/06/30 23:42 (external edit)