This shows you the differences between two versions of the page.
Both sides previous revision | Previous revision | Next revision | Both sides next revision | ||
cs501r_f2018:lab9 [2018/11/12 21:22] wingated |
cs501r_f2018:lab9 [2018/11/15 23:15] wingated |
||
---|---|---|---|
Line 73: | Line 73: | ||
.... | .... | ||
- | class AdvantageDataset(Dataset): | + | class AdvantageDataset(Dataset): |
- | .... | + | def __init__(self, experience): |
- | | + | super(AdvantageDataset, self).__init__() |
- | class PolicyDataset(Dataset): | + | self._exp = experience |
- | .... | + | self._num_runs = len(experience) |
+ | self._length = reduce(lambda acc, x: acc + len(x), experience, 0) | ||
+ | |||
+ | def __getitem__(self, index): | ||
+ | idx = 0 | ||
+ | seen_data = 0 | ||
+ | current_exp = self._exp[0] | ||
+ | while seen_data + len(current_exp) - 1 < index: | ||
+ | seen_data += len(current_exp) | ||
+ | idx += 1 | ||
+ | current_exp = self._exp[idx] | ||
+ | chosen_exp = current_exp[index - seen_data] | ||
+ | return chosen_exp[0], chosen_exp[4] | ||
+ | |||
+ | def __len__(self): | ||
+ | return self._length | ||
+ | |||
+ | | ||
+ | class PolicyDataset(Dataset): | ||
+ | def __init__(self, experience): | ||
+ | super(PolicyDataset, self).__init__() | ||
+ | self._exp = experience | ||
+ | self._num_runs = len(experience) | ||
+ | self._length = reduce(lambda acc, x: acc + len(x), experience, 0) | ||
+ | |||
+ | def __getitem__(self, index): | ||
+ | idx = 0 | ||
+ | seen_data = 0 | ||
+ | current_exp = self._exp[0] | ||
+ | while seen_data + len(current_exp) - 1 < index: | ||
+ | seen_data += len(current_exp) | ||
+ | idx += 1 | ||
+ | current_exp = self._exp[idx] | ||
+ | chosen_exp = current_exp[index - seen_data] | ||
+ | return chosen_exp | ||
+ | |||
+ | def __len__(self): | ||
+ | return self._length | ||
def main(): | def main(): |