@@ -12,7 +12,9 @@ def _get_method_type():
12
12
class C (object ):
13
13
def x (self ):
14
14
pass
15
- return type (getattr (C , 'x' ))
15
+
16
+ return type (getattr (C , "x" ))
17
+
16
18
17
19
_method_type = _get_method_type ()
18
20
@@ -40,36 +42,39 @@ def __init__(self, **kwargs):
40
42
if cls == CSVExportView :
41
43
# We can stop checking once we hit CSVExportView.
42
44
break
43
- if hasattr (cls , ' get_fields' ) and type (getattr (cls , ' get_fields' )) == _method_type :
45
+ if hasattr (cls , " get_fields" ) and isinstance (getattr (cls , " get_fields" ), _method_type ) :
44
46
get_fields_overridden = True
45
47
break
46
48
47
49
if not get_fields_overridden :
48
50
if not self .fields and not self .exclude :
49
- raise ImproperlyConfigured (' \' fields\ ' or \ ' exclude\ ' must be specified.' )
51
+ raise ImproperlyConfigured ("' fields' or 'exclude' must be specified." )
50
52
51
53
if self .fields and self .exclude :
52
- raise ImproperlyConfigured (' Specifying both \ ' fields\ ' and \ ' exclude\ ' is not permitted.' )
54
+ raise ImproperlyConfigured (" Specifying both 'fields' and 'exclude' is not permitted." )
53
55
54
56
# Check that some special functions are not being overridden.
55
- for function_override in ('get_context_data' , 'get_paginate_by' , 'get_allow_empty' , 'get_context_object_name' ):
56
- if function_override in self .__class__ .__dict__ and \
57
- type (self .__class__ .__dict__ [function_override ]) == types .FunctionType :
58
- raise ImproperlyConfigured ('Overriding \' {}()\' is not permitted.' .format (function_override ))
57
+ for function_override in ("get_context_data" , "get_paginate_by" , "get_allow_empty" , "get_context_object_name" ):
58
+ if function_override in self .__class__ .__dict__ and isinstance (
59
+ self .__class__ .__dict__ [function_override ], types .FunctionType
60
+ ):
61
+ raise ImproperlyConfigured ("Overriding '{}()' is not permitted." .format (function_override ))
59
62
60
63
if self .paginate_by :
61
- raise ImproperlyConfigured (' \' {} \ ' does not support pagination.' .format (self .__class__ .__name__ ))
64
+ raise ImproperlyConfigured ("'{} ' does not support pagination." .format (self .__class__ .__name__ ))
62
65
63
66
if not self .allow_empty :
64
- raise ImproperlyConfigured (' \' {} \ ' does not support disabling allow_empty.' .format (self .__class__ .__name__ ))
67
+ raise ImproperlyConfigured ("'{} ' does not support disabling allow_empty." .format (self .__class__ .__name__ ))
65
68
66
69
if self .context_object_name :
67
- raise ImproperlyConfigured ('\' {}\' does not support setting context_object_name.' .format (self .__class__ .__name__ ))
70
+ raise ImproperlyConfigured (
71
+ "'{}' does not support setting context_object_name." .format (self .__class__ .__name__ )
72
+ )
68
73
69
74
def get_fields (self , queryset ):
70
- """ Override if a dynamic fields are required. """
75
+ """Override if a dynamic fields are required."""
71
76
field_names = self .fields
72
- if not field_names or field_names == ' __all__' :
77
+ if not field_names or field_names == " __all__" :
73
78
opts = queryset .model ._meta
74
79
field_names = [field .name for field in opts .fields ]
75
80
@@ -81,17 +86,17 @@ def get_fields(self, queryset):
81
86
return field_names
82
87
83
88
def get_filename (self , queryset ):
84
- """ Override if a dynamic filename is required. """
89
+ """Override if a dynamic filename is required."""
85
90
filename = self .filename
86
91
if not filename :
87
- filename = queryset .model ._meta .verbose_name_plural .replace (' ' , '-' )
92
+ filename = queryset .model ._meta .verbose_name_plural .replace (" " , "-" )
88
93
return filename
89
94
90
95
def get_field_value (self , obj , field_name ):
91
- """ Override if a custom value or behaviour is required for specific fields. """
92
- if '__' not in field_name :
93
- if hasattr (obj , ' all' ) and hasattr (obj , ' iterator' ):
94
- return ',' .join ([getattr (ro , field_name ) for ro in obj .all ()])
96
+ """Override if a custom value or behaviour is required for specific fields."""
97
+ if "__" not in field_name :
98
+ if hasattr (obj , " all" ) and hasattr (obj , " iterator" ):
99
+ return "," .join ([getattr (ro , field_name ) for ro in obj .all ()])
95
100
96
101
try :
97
102
field = obj ._meta .get_field (field_name )
@@ -103,58 +108,59 @@ def get_field_value(self, obj, field_name):
103
108
104
109
value = field .value_from_object (obj )
105
110
if field .many_to_many :
106
- return ',' .join ([force_str (ro ) for ro in value ])
111
+ return "," .join ([force_str (ro ) for ro in value ])
107
112
elif field .choices :
108
- if value is None or force_str (value ).strip () == '' :
109
- return ''
113
+ if value is None or force_str (value ).strip () == "" :
114
+ return ""
110
115
return dict (field .choices )[value ]
111
116
return field .value_from_object (obj )
112
117
else :
113
- related_field_names = field_name .split ('__' )
118
+ related_field_names = field_name .split ("__" )
114
119
related_obj = getattr (obj , related_field_names [0 ])
115
- related_field_name = '__' .join (related_field_names [1 :])
120
+ related_field_name = "__" .join (related_field_names [1 :])
116
121
return self .get_field_value (related_obj , related_field_name )
117
122
118
123
def get_header_name (self , model , field_name ):
119
- """ Override if a custom value or behaviour is required for specific fields. """
120
- if '__' not in field_name :
124
+ """Override if a custom value or behaviour is required for specific fields."""
125
+ if "__" not in field_name :
121
126
try :
122
127
field = model ._meta .get_field (field_name )
123
128
except FieldDoesNotExist as e :
124
129
if not hasattr (model , field_name ):
125
130
raise e
126
131
# field_name is a property.
127
- return field_name .replace ('_' , ' ' ).title ()
132
+ return field_name .replace ("_" , " " ).title ()
128
133
129
134
return force_str (field .verbose_name ).title ()
130
135
else :
131
- related_field_names = field_name .split ('__' )
136
+ related_field_names = field_name .split ("__" )
132
137
field = model ._meta .get_field (related_field_names [0 ])
133
- assert field .is_relation
134
- return self .get_header_name (field .related_model , '__' .join (related_field_names [1 :]))
138
+ if not field .is_relation :
139
+ raise ImproperlyConfigured (f"{ field } is not a relation" )
140
+ return self .get_header_name (field .related_model , "__" .join (related_field_names [1 :]))
135
141
136
142
def get_csv_writer_fmtparams (self ):
137
143
return {
138
- ' dialect' : ' excel' ,
139
- ' quoting' : csv .QUOTE_ALL ,
144
+ " dialect" : " excel" ,
145
+ " quoting" : csv .QUOTE_ALL ,
140
146
}
141
147
142
148
def get (self , request , * args , ** kwargs ):
143
149
queryset = self .get_queryset ()
144
150
145
151
field_names = self .get_fields (queryset )
146
152
147
- response = HttpResponse (content_type = ' text/csv' )
153
+ response = HttpResponse (content_type = " text/csv" )
148
154
149
155
filename = self .get_filename (queryset )
150
- if not filename .endswith (' .csv' ):
151
- filename += ' .csv'
152
- response [' Content-Disposition' ] = 'attachment; filename="{}"' .format (filename )
156
+ if not filename .endswith (" .csv" ):
157
+ filename += " .csv"
158
+ response [" Content-Disposition" ] = 'attachment; filename="{}"' .format (filename )
153
159
154
160
writer = csv .writer (response , ** self .get_csv_writer_fmtparams ())
155
161
156
162
if self .specify_separator :
157
- response .write (' sep={}{}' .format (writer .dialect .delimiter , writer .dialect .lineterminator ))
163
+ response .write (" sep={}{}" .format (writer .dialect .delimiter , writer .dialect .lineterminator ))
158
164
159
165
if self .header :
160
166
writer .writerow ([self .get_header_name (queryset .model , field_name ) for field_name in list (field_names )])
0 commit comments