Coverage for mongo/problem/problem.py: 89%
238 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 03:01 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 03:01 +0000
1import json
2import enum
3from datetime import datetime, timedelta
4from typing import (
5 Any,
6 BinaryIO,
7 Dict,
8 List,
9 Optional,
10)
11from dataclasses import dataclass
12from io import BytesIO
13from ulid import ULID
14from .. import engine
15from ..base import MongoBase
16from ..course import *
17from ..utils import (RedisCache, doc_required, drop_none, MinioClient)
18from ..user import User
19from .exception import BadTestCase
20from .test_case import (
21 SimpleIO,
22 ContextIO,
23 IncludeDirectory,
24 TestCaseRule,
25)
# Public API of this module: only `Problem` is re-exported by star imports
# (`UploadInfo` must be imported explicitly if needed).
__all__ = ('Problem', )
@dataclass
class UploadInfo:
    '''
    Result of starting a multipart test case upload.

    Attributes:
        urls: presigned PUT URLs, one per part; the URL at index i
            uploads part number i + 1.
        upload_id: id of the multipart upload the URLs belong to; it is
            needed again to complete the upload.
    '''

    urls: List[str]  # presigned per-part upload URLs, in part order
    upload_id: str  # multipart upload id shared by every URL above
class Problem(MongoBase, engine=engine.Problem):
    '''
    Domain object wrapping the `engine.Problem` document.

    Field access such as `self.problem_id`, `self.test_case` and
    `self.update(...)` is presumably proxied to the underlying document by
    `MongoBase` (not visible here — confirm).
    '''

    class Permission(enum.IntFlag):
        VIEW = enum.auto()  # user view permission
        ONLINE = enum.auto()  # user can view problem or not
        MANAGE = enum.auto()  # user manage problem permission

    def detailed_info(self, *ks, **kns) -> Dict[str, Any]:
        '''
        Return detailed info about this problem. Notice that the `input`
        and `output` of the problem's test cases are not sent to the front
        end; another route has to be called to get them.

        A field name may contain `__` to walk into nested values
        (e.g. `testCase__tasks`); likewise a result key containing `__`
        is stored inside nested dicts of the returned value.

        Args:
            ks (*str): the field names you want to get; each name is also
                used as its own result key
            kns (**[str, str]): map a result key to the field name whose
                value should be stored under it
        Return:
            a dict containing the requested problem data, or `{}` when
            this object is falsy (presumably: the problem does not exist)
        Raises:
            KeyError: if a requested field (path) is not present
        '''
        if not self:
            return {}
        # problem -> dict
        _ret = self.to_mongo()
        # Drop test case storage details: the GridFS zip can not be
        # serialized and the MinIO object path is internal. `get`/`pop`
        # also tolerate documents that carry no test case at all.
        test_case = _ret.get('testCase')
        if test_case:
            test_case.pop('caseZip', None)
            test_case.pop('caseZipMinioPath', None)
        # convert course documents to course names
        _ret['courses'] = [course.course_name for course in self.courses]
        ret = {}
        for k in ks:
            kns[k] = k
        for k, n in kns.items():
            # extract wanted value by following the `__`-separated path
            s_ns = n.split('__')
            v = _ret[s_ns[0]]
            for s_n in s_ns[1:]:
                v = v[s_n]
            # store it under the wanted (possibly nested) key
            e = ret
            s_ks = k.split('__')
            for s_k in s_ks[:-1]:
                e = e.setdefault(s_k, {})
            e[s_ks[-1]] = v
        return ret

    def allowed(self, language):
        '''
        Check whether `language` may be used to submit to this problem.

        `allowed_language` is a bitmask indexed by language id; only ids
        0, 1 and 2 are recognized. Problems of type 2 (which `add` treats
        as non-programming problems) accept any language.
        '''
        if self.problem_type == 2:
            return True
        if not 0 <= language < 3:
            return False
        return bool((1 << language) & self.allowed_language)

    def submit_count(self, user: User) -> int:
        '''
        Calculate how many submissions the user has submitted to this
        problem today.

        If the user's last submission was not made today, all of the
        user's per-problem counters are reset and 0 is returned.
        '''
        # reset quota if it's a new day
        if user.last_submit.date() != datetime.now().date():
            user.update(problem_submission={})
            return 0
        return user.problem_submission.get(str(self.problem_id), 0)

    def running_homeworks(self) -> List:
        '''
        Return the homeworks of this problem whose duration contains the
        current time.
        '''
        # imported lazily, presumably to avoid a circular import
        from ..homework import Homework
        now = datetime.now()
        return [Homework(hw.id) for hw in self.homeworks if now in hw.duration]

    def is_valid_ip(self, ip: str):
        '''
        Return True when every running homework of this problem accepts
        `ip` (vacuously True when no homework is running).
        '''
        return all(hw.is_valid_ip(ip) for hw in self.running_homeworks())

    def get_submission_status(self) -> Dict[str, int]:
        '''
        Count this problem's submissions grouped by their status.

        Returns:
            a dict mapping each submission status to its number of
            submissions
        '''
        pipeline = {
            "$group": {
                "_id": "$status",
                "count": {
                    "$sum": 1
                },
            }
        }
        cursor = engine.Submission.objects(problem=self.id).aggregate(
            [pipeline], )
        return {item['_id']: item['count'] for item in cursor}

    def get_ac_user_count(self) -> int:
        '''Return the number of distinct users with a status-0 submission.'''
        ac_users = engine.Submission.objects(
            problem=self.id,
            status=0,
        ).distinct('user')
        return len(ac_users)

    def get_tried_user_count(self) -> int:
        '''Return the number of distinct users who submitted at all.'''
        tried_users = engine.Submission.objects(
            problem=self.id, ).distinct('user')
        return len(tried_users)

    @doc_required('user', User)
    def high_score_key(self, user: User) -> str:
        '''Return the Redis key caching `user`'s high score here.'''
        return f'PROBLEM_{self.id}_{user.id}_HIGH_SCORE'

    @doc_required('user', User)
    def get_high_score(self, user: User) -> int:
        '''
        Get highest score for user of this problem.

        The value is read from Redis when cached; otherwise it is computed
        from the user's submissions and cached with `ex=600` (seconds,
        presumably, per Redis `SET EX`).
        '''
        cache = RedisCache()
        key = self.high_score_key(user=user)
        if (val := cache.get(key)) is not None:
            return int(val.decode())
        # TODO: avoid calling mongoengine API directly
        # a single query fetching only the best submission (if any)
        best = engine.Submission.objects(
            user=user.id,
            problem=self.id,
        ).only('score').order_by('-score').first()
        # score might be < 0 if there are only incomplete submissions
        high_score = 0 if best is None else max(best.score, 0)
        cache.set(key, high_score, ex=600)
        return high_score

    @doc_required('user', User)
    def own_permission(self, user: User) -> Permission:
        """
        Generate the permission capability `user` has on this problem.

        - VIEW is inherited from VIEW permission on any of its courses.
        - ONLINE is granted when the problem status is 0 and, in some
          course, it either belongs to no homework or belongs to a
          homework that has already started.
        - Admins (role 0) and the problem owner get every permission.
        """
        now = datetime.now()
        user_cap = self.Permission(0)
        for course in map(Course, self.courses):
            # inherit course permission
            if course.permission(user, Course.Permission.VIEW):
                user_cap |= self.Permission.VIEW

            # online problem
            if self.problem_status == 0:
                check_public_problem = True
                for homework in course.homeworks:
                    if self.problem_id in homework.problem_ids:
                        check_public_problem = False
                        # homework already started, so the problem is online
                        if now >= homework.duration.start:
                            user_cap |= self.Permission.ONLINE

                # problem does not belong to any homework
                if check_public_problem:
                    user_cap |= self.Permission.ONLINE

        # Admin, Teacher && is owner
        if user.role == 0 or self.owner == user.username:
            user_cap |= self.Permission.VIEW
            user_cap |= self.Permission.ONLINE
            user_cap |= self.Permission.MANAGE

        return user_cap

    def permission(self, user: User, req: Permission) -> bool:
        """
        Check whether `user` owns every flag in `req`.
        """
        return (self.own_permission(user=user) & req) == req

    @classmethod
    def get_problem_list(
        cls,
        user,
        offset: int = 0,
        count: int = -1,
        problem_id: int = None,
        name: str = None,
        tags: list = None,
        course: str = None,
    ):
        '''
        Get a list of problems `user` has ONLINE permission on, ordered by
        problem id.

        Args:
            user: the user requesting the list
            offset: index of the first problem to return
            count: max number of problems to return; negative means all
            problem_id, name, tags, course: optional filters (`None` is
                ignored); `tags` matches problems having any of the tags
        Returns:
            a list of `engine.Problem` documents, or `[]` if `course`
            does not exist
        Raises:
            IndexError: if `offset` is negative or past the end of a
                non-empty result
        '''
        if course is not None:
            course = Course(course)
            if not course:
                return []
            course = course.obj
        # query args
        ks = drop_none({
            'problem_id': problem_id,
            'problem_name': name,
            'courses': course,
            'tags__in': tags,
        })
        problems = [
            p for p in engine.Problem.objects(**ks).order_by('problemId')
            if cls(p).permission(user=user, req=cls.Permission.ONLINE)
        ]
        # truncate
        if offset < 0 or (offset >= len(problems) and len(problems)):
            raise IndexError
        right = len(problems) if count < 0 else offset + count
        right = min(len(problems), right)
        return problems[offset:right]

    @classmethod
    def add(
        cls,
        user: User,
        courses: List[str],
        problem_name: str,
        status: Optional[int] = None,
        description: Optional[Dict[str, Any]] = None,
        tags: Optional[List[str]] = None,
        type: Optional[int] = None,
        test_case_info: Optional[Dict[str, Any]] = None,
        can_view_stdout: bool = False,
        allowed_language: Optional[int] = None,
        quota: Optional[int] = None,
        default_code: Optional[str] = None,
    ):
        '''
        Create a new problem owned by `user`.

        `None` arguments are dropped so the document defaults apply.
        Programming-specific fields (`test_case_info`, `can_view_stdout`,
        `allowed_language`) are only stored when `type != 2`.

        Returns:
            the `problem_id` of the new problem
        Raises:
            ValueError: if `courses` is empty
            engine.DoesNotExist: if any course name does not exist
        '''
        if len(courses) == 0:
            raise ValueError('No course provided')
        course_objs = []
        for course in map(Course, courses):
            if not course:
                raise engine.DoesNotExist
            course_objs.append(course.id)
        problem_args = drop_none({
            'courses': course_objs,
            'problem_status': status,
            'problem_type': type,
            'problem_name': problem_name,
            'description': description,
            'owner': user.username,
            'tags': tags,
            'quota': quota,
            'default_code': default_code,
        })
        problem = cls.engine(**problem_args).save()
        programming_problem_args = drop_none({
            'test_case':
            test_case_info,
            'can_view_stdout':
            can_view_stdout,
            'allowed_language':
            allowed_language,
        })
        if programming_problem_args and type != 2:
            problem.update(**programming_problem_args)
        return problem.problem_id

    @classmethod
    def edit_problem(
        cls,
        user: User,
        problem_id: int,
        courses: List[str],
        status: int,
        problem_name: str,
        description: Dict[str, Any],
        tags: List[str],
        type,
        test_case_info: Optional[Dict[str, Any]] = None,
        allowed_language: int = 7,
        can_view_stdout: bool = False,
        quota: int = -1,
        default_code: str = '',
    ):
        '''
        Overwrite a problem's fields; `user` becomes its owner.

        When `test_case_info` is omitted/empty the existing test case is
        kept. The stored zip location (GridFS / MinIO path) is always
        carried over to the new test case.

        Raises:
            ValueError: if the given tasks' scores do not sum to 100
            engine.DoesNotExist: if any course name does not exist
        '''
        # Only validate a test case that was actually provided; the
        # default `None` means "keep the current one" (see below).
        if type != 2 and test_case_info:
            score = sum(t['taskScore'] for t in test_case_info['tasks'])
            if score != 100:
                raise ValueError("Cases' scores should be 100 in total")
        problem = Problem(problem_id).obj
        course_objs = []
        for name in courses:
            if not (course := Course(name)):
                raise engine.DoesNotExist
            course_objs.append(course.obj)
        problem.update(
            courses=course_objs,
            problem_status=status,
            problem_type=type,
            problem_name=problem_name,
            description=description,
            owner=user.username,
            tags=tags,
            quota=quota,
            default_code=default_code,
        )
        if type != 2:
            # preprocess test case
            test_case = problem.test_case
            if test_case_info:
                test_case = engine.ProblemTestCase.from_json(
                    json.dumps(test_case_info))
                # keep the already-uploaded zip
                test_case.case_zip = problem.test_case.case_zip
                test_case.case_zip_minio_path = problem.test_case.case_zip_minio_path
            problem.update(
                allowed_language=allowed_language,
                can_view_stdout=can_view_stdout,
                test_case=test_case,
            )

    def update_test_case(self, test_case: BinaryIO):
        '''
        edit problem's testcase

        Args:
            test_case: testcase zip file
        Exceptions:
            BadTestCase: if the zip content does not match a known format
            zipfile.BadZipFile: if `test_case` is not a zip file
            ValueError: if test case is None or problem_id is invalid
            engine.DoesNotExist
        '''
        self._validate_test_case(test_case)
        # validation reads the stream; rewind before uploading it
        test_case.seek(0)
        self._save_test_case_zip(test_case)

    def _save_test_case_zip(self, test_case: BinaryIO):
        '''
        Upload the test case zip to MinIO under a fresh object path and
        point this problem at it.
        '''
        minio_client = MinioClient()
        path = self._generate_test_case_obj_path()
        minio_client.client.put_object(
            minio_client.bucket,
            path,
            test_case,
            -1,  # unknown length: stream in parts of `part_size`
            part_size=5 * 1024 * 1024,
            content_type='application/zip',
        )
        self.update(test_case__case_zip_minio_path=path)
        self.reload('test_case')

    def _generate_test_case_obj_path(self):
        '''Return a new unique MinIO object path for a test case zip.'''
        return f'problem-test-case/{ULID()}.zip'

    def _validate_test_case(self, test_case: BinaryIO):
        '''
        Validate the test case zip, raise BadTestCase if invalid.

        Directory rules must all pass; then exactly one of the SimpleIO /
        ContextIO formats must match.
        '''
        rules: List[TestCaseRule] = [
            IncludeDirectory(self, 'include'),
            IncludeDirectory(self, 'share'),
            # for backward compatibility
            IncludeDirectory(self, 'chaos'),
        ]
        for rule in rules:
            rule.validate(test_case)

        # Should only match one format
        rules = [
            SimpleIO(self, ['include/', 'share/', 'chaos/']),
            ContextIO(self),
        ]
        excs = []
        for rule in rules:
            try:
                rule.validate(test_case)
            except BadTestCase as e:
                excs.append(e)

        if len(excs) == 0:
            raise BadTestCase('ambiguous test case format')
        elif len(excs) == 2:
            raise BadTestCase(
                f'invalid test case format\n\n{excs[0]}\n\n{excs[1]}')

    @classmethod
    def copy_problem(cls, user, problem_id):
        '''
        Duplicate a problem, owned by `user`.

        Only status, type, name, description, tags and test case are
        copied; courses, `can_view_stdout`, `allowed_language`, `quota`
        and `default_code` are not (see `copy_to` for a fuller copy).
        '''
        problem = Problem(problem_id).obj
        engine.Problem(
            problem_status=problem.problem_status,
            problem_type=problem.problem_type,
            problem_name=problem.problem_name,
            description=problem.description,
            owner=user.username,
            tags=problem.tags,
            test_case=problem.test_case,
        ).save()

    @doc_required('target', Course, src_none_allowed=True)
    def copy_to(
        self,
        user: User,
        target: Optional[Course] = None,
        **override,
    ) -> int:
        '''
        Copy a problem to target course, hidden by default.

        Args:
            user (User): The user who executes this action and will become
                the owner of the copied problem.
            target (Optional[Course] = None): The course this problem will
                be copied to, default to the first of the original courses.
            override: Override field values passed to `Problem.add`.
        Returns:
            the `problem_id` of the copied problem (as returned by
            `Problem.add`)
        '''
        target = self.courses[0] if target is None else target
        # Copied problem is hidden by default
        status = override.pop('status', Problem.engine.Visibility.HIDDEN)
        ks = dict(
            user=user,
            courses=[target.course_name],
            problem_name=self.problem_name,
            status=status,
            description=self.description.to_mongo(),
            tags=self.tags,
            type=self.problem_type,
            test_case_info=self.test_case.to_mongo(),
            can_view_stdout=self.can_view_stdout,
            allowed_language=self.allowed_language,
            quota=self.quota,
            default_code=self.default_code,
        )
        ks.update(override)
        return self.add(**ks)

    @classmethod
    def release_problem(cls, problem_id):
        '''
        Move a problem into the 'Public' course only and hand ownership
        to 'first_admin'.
        '''
        course = Course('Public').obj
        problem = Problem(problem_id).obj
        problem.courses = [course]
        problem.owner = 'first_admin'
        problem.save()

    def is_test_case_ready(self) -> bool:
        '''Return True if a zip is stored in GridFS or MinIO.'''
        return (self.test_case.case_zip.grid_id is not None
                or self.test_case.case_zip_minio_path is not None)

    def get_test_case(self) -> BinaryIO:
        '''
        Return the test case zip as a readable binary stream.

        MinIO storage is preferred; the object is fully read into memory
        and the HTTP response is always closed and its connection
        released. Otherwise the legacy GridFS file is returned.
        '''
        if self.test_case.case_zip_minio_path is not None:
            minio_client = MinioClient()
            resp = minio_client.client.get_object(
                minio_client.bucket,
                self.test_case.case_zip_minio_path,
            )
            try:
                return BytesIO(resp.read())
            finally:
                resp.close()
                resp.release_conn()

        # fallback to legacy GridFS storage
        return self.test_case.case_zip

    # TODO: hope minio SDK to provide more high-level API
    def generate_urls_for_uploading_test_case(
        self,
        length: int,
        part_size: int,
    ) -> UploadInfo:
        '''
        Start a multipart upload of a new test case zip to MinIO and
        return one presigned PUT URL per part (valid for 30 minutes).

        The problem is pointed at the new object path immediately.

        Args:
            length: total size of the zip (presumably in bytes)
            part_size: size of each part, same unit; must be positive
        Raises:
            ValueError: if `part_size` is not positive
        '''
        # Validate before touching the database / MinIO so bad input
        # leaves no half-created upload behind.
        if part_size <= 0:
            raise ValueError('part_size must be positive')
        # TODO: update url after uploading completed
        # TODO: handle failed uploading
        path = self._generate_test_case_obj_path()
        self.update(test_case__case_zip_minio_path=path)

        minio_client = MinioClient()
        upload_id = minio_client.client._create_multipart_upload(
            minio_client.bucket,
            path,
            headers={'Content-Type': 'application/zip'},
        )
        # ceil(length / part_size)
        part_count = (length + part_size - 1) // part_size

        def presigned_part_url(i: int):
            # S3 part numbers are 1-based
            return minio_client.client.get_presigned_url(
                'PUT',
                minio_client.bucket,
                path,
                expires=timedelta(minutes=30),
                extra_query_params={
                    'partNumber': str(i + 1),
                    'uploadId': upload_id
                },
            )

        return UploadInfo(
            urls=[presigned_part_url(i) for i in range(part_count)],
            upload_id=upload_id,
        )

    def complete_test_case_upload(self, upload_id: str, parts: list):
        '''
        Finish a multipart upload started by
        `generate_urls_for_uploading_test_case`, then validate the zip.

        Args:
            upload_id: id returned in `UploadInfo.upload_id`
            parts: uploaded parts passed to minio's
                `_complete_multipart_upload` (presumably part number +
                ETag entries — confirm)
        Raises:
            BadTestCase: if the uploaded zip is invalid; the problem's
                MinIO path is cleared first
        '''
        minio_client = MinioClient()
        minio_client.client._complete_multipart_upload(
            minio_client.bucket,
            self.test_case.case_zip_minio_path,
            upload_id,
            parts,
        )

        try:
            test_case = self.get_test_case()
            self._validate_test_case(test_case)
        except BadTestCase:
            # NOTE(review): the invalid object itself is left in MinIO.
            self.update(test_case__case_zip_minio_path=None)
            raise