diff --git a/sponsors/forms.py b/sponsors/forms.py index 2ec0bcc54..a1fd8fcdd 100644 --- a/sponsors/forms.py +++ b/sponsors/forms.py @@ -89,11 +89,15 @@ def benefits_conflicts(self): conflicts[benefit.id] = list(benefits_conflicts) return conflicts - def get_benefits(self, cleaned_data=None): + def get_benefits(self, cleaned_data=None, include_add_ons=False): cleaned_data = cleaned_data or self.cleaned_data - return list( + benefits = list( chain(*(cleaned_data.get(bp.name) for bp in self.benefits_programs)) ) + add_ons = cleaned_data.get("add_ons_benefits") + if include_add_ons and add_ons: + benefits.extend([b for b in add_ons]) + return benefits def get_package(self): return self.cleaned_data.get("package") diff --git a/sponsors/tests/test_forms.py b/sponsors/tests/test_forms.py index 1f09262b1..0f18859c6 100644 --- a/sponsors/tests/test_forms.py +++ b/sponsors/tests/test_forms.py @@ -125,6 +125,24 @@ def test_invalid_form_if_any_conflict(self): form.errors["__all__"], ) + def test_get_benefits_from_cleaned_data(self): + benefit = self.program_1_benefits[0] + + data = {"benefits_psf": [benefit.id], + "add_ons_benefits": [b.id for b in self.add_ons]} + form = SponsorshiptBenefitsForm(data=data) + self.assertTrue(form.is_valid()) + + benefits = form.get_benefits() + self.assertEqual(1, len(benefits)) + self.assertIn(benefit, benefits) + + benefits = form.get_benefits(include_add_ons=True) + self.assertEqual(3, len(benefits)) + self.assertIn(benefit, benefits) + for add_on in self.add_ons: + self.assertIn(add_on, benefits) + def test_package_only_benefit_without_package_should_not_validate(self): SponsorshipBenefit.objects.all().update(package_only=True) diff --git a/sponsors/tests/test_views.py b/sponsors/tests/test_views.py index 3dc213b91..8a77fb7af 100644 --- a/sponsors/tests/test_views.py +++ b/sponsors/tests/test_views.py @@ -152,10 +152,15 @@ def setUp(self): self.package = baker.make("sponsors.SponsorshipPackage", advertise=True) for benefit in self.program_1_benefits: benefit.packages.add(self.package) + + # packages without associated packages + self.add_on = baker.make(SponsorshipBenefit) + self.client.cookies["sponsorship_selected_benefits"] = json.dumps( { "package": self.package.id, "benefits_psf": [b.id for b in self.program_1_benefits], + "add_ons_benefits": [self.add_on.id], } ) self.data = { @@ -176,7 +181,8 @@ def setUp(self): "web_logo": get_static_image_file_as_upload("psf-logo.png", "logo.png"), } - def test_display_template_with_form_and_context(self): + def test_display_template_with_form_and_context_without_add_ons(self): + self.add_on.delete() r = self.client.get(self.url) self.assertEqual(r.status_code, 200) @@ -193,6 +199,13 @@ def test_display_template_with_form_and_context(self): for benefit in self.program_1_benefits: self.assertIn(benefit, r.context["sponsorship_benefits"]) + def test_display_template_with_form_and_context_with_add_ons(self): + r = self.client.get(self.url) + + self.assertEqual(r.status_code, 200) + self.assertIn(self.add_on, r.context["added_benefits"]) + self.assertIsNone(r.context["sponsorship_price"]) + def test_return_package_as_none_if_not_previously_selected(self): self.client.cookies["sponsorship_selected_benefits"] = json.dumps( { @@ -277,6 +290,8 @@ def test_create_new_sponsorship(self): ) sponsorship = Sponsorship.objects.get(sponsor__name="CompanyX") self.assertTrue(sponsorship.benefits.exists()) + # 3 benefits + 1 add-on + self.assertEqual(4, sponsorship.benefits.count()) self.assertTrue(sponsorship.level_name) self.assertTrue(sponsorship.submited_by, self.user) self.assertEqual( diff --git a/sponsors/views.py b/sponsors/views.py index 0ffbda446..766649647 100644 --- a/sponsors/views.py +++ b/sponsors/views.py @@ -154,7 +154,7 @@ def form_valid(self, form): sponsorship = uc.execute( self.request.user, sponsor, - benefits_form.get_benefits(), + benefits_form.get_benefits(include_add_ons=True), benefits_form.get_package(), request=self.request, )