Coverage for mongo/problem/problem.py: 79%
272 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 18:37 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-11 18:37 +0000
1import json
2import enum
3from hashlib import md5
4from datetime import datetime, timedelta
5from typing import (
6 Any,
7 BinaryIO,
8 Dict,
9 List,
10 Optional,
11)
12from dataclasses import dataclass
13from io import BytesIO
14from ulid import ULID
15from .. import engine
16from ..base import MongoBase
17from ..course import *
18from ..utils import (RedisCache, doc_required, drop_none, MinioClient)
19from ..user import User
20from .exception import BadTestCase
21from .test_case import (
22 SimpleIO,
23 ContextIO,
24 IncludeDirectory,
25 TestCaseRule,
26)
28__all__ = ('Problem', )
@dataclass
class UploadInfo:
    '''
    Result of preparing a multipart upload of a problem's test case zip.
    '''
    # presigned PUT urls, the i-th one uploads part number i + 1
    urls: List[str]
    # multipart upload id, needed again to complete the upload
    upload_id: str
class Problem(MongoBase, engine=engine.Problem):
    '''
    Wrapper around `engine.Problem` documents.

    Provides permission checks, submission statistics, helpers to create,
    edit and copy problems, and test case storage (MinIO, with a fallback
    to the legacy GridFS `case_zip` field).
    '''

    class Permission(enum.IntFlag):
        VIEW = enum.auto()  # user view permission
        ONLINE = enum.auto()  # user can view problem or not
        MANAGE = enum.auto()  # user manage problem permission

    def detailed_info(self, *ks, **kns) -> Dict[str, Any]:
        '''
        Return detailed info about this problem. Notice that the `input`
        and `output` of problem test cases won't be sent to the front end;
        another route must be called to get them.

        Field names may contain `__` to address nested values
        (e.g. `testCase__tasks`), and result keys may contain `__` to
        build nested dicts in the returned value.

        Args:
            ks (*str): field names to fetch, each stored under the same key
            kns (**[str, str]): maps a result key to the field name whose
                data should be stored under that key
        Returns:
            a dict containing the requested problem data, or `{}` when
            `self` is falsy (presumably the problem does not exist)
        '''
        if not self:
            return {}
        # problem -> dict
        _ret = self.to_mongo()
        # preprocess fields
        test_case = _ret.get('testCase')
        if test_case is not None:
            # case zip can not be serialized
            test_case.pop('caseZip', None)
            # skip minio path, it is an internal storage location
            test_case.pop('caseZipMinioPath', None)
        # convert course documents to course names
        _ret['courses'] = [course.course_name for course in self.courses]
        ret = {}
        for k in ks:
            kns[k] = k
        for k, n in kns.items():
            # extract wanted value by walking the `__`-separated field path
            s_ns = n.split('__')
            v = _ret[s_ns[0]]
            for s_n in s_ns[1:]:
                v = v[s_n]
            # store it under the `__`-separated result key, creating
            # intermediate dicts as needed
            e = ret
            s_ks = k.split('__')
            for s_k in s_ks[:-1]:
                if s_k not in e:
                    e[s_k] = {}
                e = e[s_k]
            e[s_ks[-1]] = v
        return ret

    def allowed(self, language: int) -> bool:
        '''
        Check whether submissions in `language` are accepted.

        Problems with `problem_type == 2` accept any language. Otherwise
        `language` must be in [0, 3) and its bit must be set in the
        `allowed_language` bitmask.
        '''
        if self.problem_type == 2:
            return True
        if language >= 3 or language < 0:
            return False
        return bool((1 << language) & self.allowed_language)

    def submit_count(self, user: User) -> int:
        '''
        Calculate how many submissions the user has submitted to this
        problem today.

        If the user's last submission was not today, their per-problem
        counters are reset and 0 is returned.
        '''
        # reset quota if it's a new day
        if user.last_submit.date() != datetime.now().date():
            user.update(problem_submission={})
            return 0
        return user.problem_submission.get(str(self.problem_id), 0)

    def running_homeworks(self) -> List:
        '''
        Return the homeworks containing this problem whose duration
        includes the current time.
        '''
        from ..homework import Homework
        now = datetime.now()
        return [Homework(hw.id) for hw in self.homeworks if now in hw.duration]

    def is_valid_ip(self, ip: str) -> bool:
        '''
        Check `ip` against every running homework of this problem.

        Returns True when no homework is running.
        '''
        return all(hw.is_valid_ip(ip) for hw in self.running_homeworks())

    def get_submission_status(self) -> Dict[str, int]:
        '''
        Count this problem's submissions grouped by their `status` value.
        '''
        pipeline = {
            "$group": {
                "_id": "$status",
                "count": {
                    "$sum": 1
                },
            }
        }
        cursor = engine.Submission.objects(problem=self.id).aggregate(
            [pipeline], )
        return {item['_id']: item['count'] for item in cursor}

    def get_ac_user_count(self) -> int:
        '''
        Count distinct users having a submission with status 0 to this
        problem.
        '''
        ac_users = engine.Submission.objects(
            problem=self.id,
            status=0,
        ).distinct('user')
        return len(ac_users)

    def get_tried_user_count(self) -> int:
        '''
        Count distinct users having any submission to this problem.
        '''
        tried_users = engine.Submission.objects(
            problem=self.id, ).distinct('user')
        return len(tried_users)

    @doc_required('user', User)
    def high_score_key(self, user: User) -> str:
        '''
        Redis key caching `user`'s high score on this problem.
        '''
        return f'PROBLEM_{self.id}_{user.id}_HIGH_SCORE'

    @doc_required('user', User)
    def get_high_score(self, user: User) -> int:
        '''
        Get highest score for user of this problem.

        The value is cached in Redis for 600 seconds. Scores below 0 are
        reported as 0.
        '''
        cache = RedisCache()
        key = self.high_score_key(user=user)
        if (val := cache.get(key)) is not None:
            return int(val.decode())
        # TODO: avoid calling mongoengine API directly
        # a single query: the best submission, or None if there is none
        submission = engine.Submission.objects(
            user=user.id,
            problem=self.id,
        ).only('score').order_by('-score').first()
        if submission is None:
            high_score = 0
        else:
            # It might < 0 if there is only incomplete submission
            high_score = max(submission.score, 0)
        cache.set(key, high_score, ex=600)
        return high_score

    @doc_required('user', User)
    def own_permission(self, user: User) -> Permission:
        """
        Compute every permission `user` holds on this problem.

        - VIEW: inherited from VIEW permission on any of the problem's courses.
        - ONLINE: for a problem with `problem_status == 0`, granted through
          a viewable course if the problem belongs to none of that course's
          homeworks, or one of those homeworks has already started.
        - Users with role 0 and the problem owner get VIEW, ONLINE and MANAGE.
        """
        user_cap = self.Permission(0)
        for course in map(Course, self.courses):
            # inherit course permission
            if course.permission(user, Course.Permission.VIEW):
                user_cap |= self.Permission.VIEW

                # online problem
                if self.problem_status == 0:
                    check_public_problem = True
                    for homework in course.homeworks:
                        if self.problem_id in homework.problem_ids:
                            check_public_problem = False
                            # current time after homework then online problem
                            if datetime.now() >= homework.duration.start:
                                user_cap |= self.Permission.ONLINE

                    # problem does not belong to any homework
                    if check_public_problem:
                        user_cap |= self.Permission.ONLINE

        # Admin, Teacher && is owner
        if user.role == 0 or self.owner == user.username:
            user_cap |= self.Permission.VIEW
            user_cap |= self.Permission.ONLINE
            user_cap |= self.Permission.MANAGE

        return user_cap

    def permission(self, user: User, req: Permission) -> bool:
        """
        check whether user own `req` permission (all bits of it)
        """
        return (self.own_permission(user=user) & req) == req

    @classmethod
    def get_problem_list(
        cls,
        user,
        offset: int = 0,
        count: int = -1,
        problem_id: Optional[int] = None,
        name: Optional[str] = None,
        tags: Optional[list] = None,
        course: Optional[str] = None,
    ):
        '''
        Get a list of problems that `user` has ONLINE permission on,
        ordered by problem id.

        Args:
            offset: index of the first problem to return
            count: max number of problems to return, negative means all
            problem_id, name, tags, course: optional filters; `tags`
                matches problems having any of the given tags
        Returns:
            a list of `engine.Problem` documents, `[]` if `course` is
            given but does not exist
        Raises:
            IndexError: if `offset` is negative, or the filtered list is
                non-empty and `offset` is past its end
        '''
        if course is not None:
            course = Course(course)
            if not course:
                return []
            course = course.obj
        # query args
        ks = drop_none({
            'problem_id': problem_id,
            'problem_name': name,
            'courses': course,
            'tags__in': tags,
        })
        problems = [
            p for p in engine.Problem.objects(**ks).order_by('problemId')
            if cls(p).permission(user=user, req=cls.Permission.ONLINE)
        ]
        # truncate
        if offset < 0 or (offset >= len(problems) and len(problems)):
            raise IndexError
        right = len(problems) if count < 0 else offset + count
        right = min(len(problems), right)
        return problems[offset:right]

    @classmethod
    def add(
        cls,
        user: User,
        courses: List[str],
        problem_name: str,
        status: Optional[int] = None,
        description: Optional[Dict[str, Any]] = None,
        tags: Optional[List[str]] = None,
        type: Optional[int] = None,
        test_case_info: Optional[Dict[str, Any]] = None,
        can_view_stdout: bool = False,
        allowed_language: Optional[int] = None,
        quota: Optional[int] = None,
        default_code: Optional[str] = None,
    ) -> int:
        '''
        Create a new problem owned by `user`.

        `None` arguments are dropped so the document defaults apply.
        Programming fields (`test_case_info`, `can_view_stdout`,
        `allowed_language`) are only stored when `type != 2`.

        Returns:
            the new problem's `problem_id`
        Raises:
            ValueError: if `courses` is empty
            engine.DoesNotExist: if any course does not exist
        '''
        if len(courses) == 0:
            raise ValueError('No course provided')
        course_objs = []
        for course in map(Course, courses):
            if not course:
                raise engine.DoesNotExist
            course_objs.append(course.id)
        problem_args = drop_none({
            'courses': course_objs,
            'problem_status': status,
            'problem_type': type,
            'problem_name': problem_name,
            'description': description,
            'owner': user.username,
            'tags': tags,
            'quota': quota,
            'default_code': default_code,
        })
        problem = cls.engine(**problem_args).save()
        programming_problem_args = drop_none({
            'test_case': test_case_info,
            'can_view_stdout': can_view_stdout,
            'allowed_language': allowed_language,
        })
        if programming_problem_args and type != 2:
            problem.update(**programming_problem_args)
        return problem.problem_id

    @classmethod
    def edit_problem(
        cls,
        user: User,
        problem_id: int,
        courses: List[str],
        status: int,
        problem_name: str,
        description: Dict[str, Any],
        tags: List[str],
        type,
        test_case_info: Optional[Dict[str, Any]] = None,
        allowed_language: int = 7,
        can_view_stdout: bool = False,
        quota: int = -1,
        default_code: str = '',
    ):
        '''
        Overwrite an existing problem's fields; `user` becomes its owner.

        For `type != 2` the programming fields are updated too. When
        `test_case_info` is falsy the current test case is kept; otherwise
        it replaces the test case metadata while keeping the stored zip
        (GridFS and MinIO path).

        Raises:
            ValueError: if the given task scores do not sum to 100
            engine.DoesNotExist: if any course does not exist
        '''
        # only validate scores when new test case info is actually given;
        # a missing `test_case_info` means "keep the current test case"
        if type != 2 and test_case_info:
            score = sum(t['taskScore'] for t in test_case_info['tasks'])
            if score != 100:
                raise ValueError("Cases' scores should be 100 in total")
        problem = Problem(problem_id).obj
        course_objs = []
        for name in courses:
            if not (course := Course(name)):
                raise engine.DoesNotExist
            course_objs.append(course.obj)
        problem.update(
            courses=course_objs,
            problem_status=status,
            problem_type=type,
            problem_name=problem_name,
            description=description,
            owner=user.username,
            tags=tags,
            quota=quota,
            default_code=default_code,
        )
        if type != 2:
            # preprocess test case
            test_case = problem.test_case
            if test_case_info:
                test_case = engine.ProblemTestCase.from_json(
                    json.dumps(test_case_info))
                # keep the already uploaded test case files
                test_case.case_zip = problem.test_case.case_zip
                test_case.case_zip_minio_path = problem.test_case.case_zip_minio_path
            problem.update(
                allowed_language=allowed_language,
                can_view_stdout=can_view_stdout,
                test_case=test_case,
            )

    def update_test_case(self, test_case: BinaryIO):
        '''
        edit problem's testcase

        Args:
            test_case: testcase zip file
        Exceptions:
            BadTestCase: if the zip layout is invalid
            zipfile.BadZipFile: if `test_case` is not a zip file
            ValueError: if test case is None or problem_id is invalid
            engine.DoesNotExist
        '''
        self._validate_test_case(test_case)
        # validation consumed the stream, rewind before uploading
        test_case.seek(0)
        self._save_test_case_zip(test_case)

    def _save_test_case_zip(self, test_case: BinaryIO):
        '''
        Upload the test case zip to MinIO under a fresh object path and
        point this problem's `case_zip_minio_path` to it.
        '''
        minio_client = MinioClient()
        path = self._generate_test_case_obj_path()
        minio_client.client.put_object(
            minio_client.bucket,
            path,
            test_case,
            -1,  # unknown length, upload in parts of `part_size`
            part_size=5 * 1024 * 1024,
            content_type='application/zip',
        )
        self.update(test_case__case_zip_minio_path=path)
        self.reload('test_case')

    def _generate_test_case_obj_path(self) -> str:
        '''
        Return a new unique MinIO object path for a test case zip.
        '''
        return f'problem-test-case/{ULID()}.zip'

    def _validate_test_case(self, test_case: BinaryIO):
        '''
        validate test case, raise BadTestCase if invalid

        Every directory rule must pass, and exactly one of the test case
        formats (SimpleIO / ContextIO) must match.
        '''
        directory_rules: List[TestCaseRule] = [
            IncludeDirectory(self, 'include'),
            IncludeDirectory(self, 'share'),
            # for backward compatibility
            IncludeDirectory(self, 'chaos'),
        ]
        for rule in directory_rules:
            rule.validate(test_case)

        # Should only match one format
        format_rules = [
            SimpleIO(self, ['include/', 'share/', 'chaos/']),
            ContextIO(self),
        ]
        excs = []
        for rule in format_rules:
            try:
                rule.validate(test_case)
            except BadTestCase as e:
                excs.append(e)

        if len(excs) == 0:
            raise BadTestCase('ambiguous test case format')
        elif len(excs) == 2:
            raise BadTestCase(
                f'invalid test case format\n\n{excs[0]}\n\n{excs[1]}')

    @classmethod
    def copy_problem(cls, user, problem_id):
        '''
        Legacy copy: save a new problem owned by `user` with the status,
        type, name, description, tags and test case of `problem_id`.
        Other fields (courses, languages, quota, ...) are not copied.
        '''
        problem = Problem(problem_id).obj
        engine.Problem(
            problem_status=problem.problem_status,
            problem_type=problem.problem_type,
            problem_name=problem.problem_name,
            description=problem.description,
            owner=user.username,
            tags=problem.tags,
            test_case=problem.test_case,
        ).save()

    @doc_required('target', Course, src_none_allowed=True)
    def copy_to(
        self,
        user: User,
        target: Optional[Course] = None,
        **override,
    ) -> int:
        '''
        Copy a problem to target course, hidden by default.

        Args:
            user (User): The user who executes this action and will become
                the owner of the copied problem.
            target (Optional[Course] = None): The course this problem will
                be copied to, default to the first of the original courses.
            override: Override field values passed to `Problem.add`.
        Returns:
            the `problem_id` of the copied problem (as returned by `add`)
        '''
        target = self.courses[0] if target is None else target
        # Copied problem is hidden by default
        status = override.pop('status', Problem.engine.Visibility.HIDDEN)
        ks = dict(
            user=user,
            courses=[target.course_name],
            problem_name=self.problem_name,
            status=status,
            description=self.description.to_mongo(),
            tags=self.tags,
            type=self.problem_type,
            test_case_info=self.test_case.to_mongo(),
            can_view_stdout=self.can_view_stdout,
            allowed_language=self.allowed_language,
            quota=self.quota,
            default_code=self.default_code,
        )
        ks.update(override)
        copy = self.add(**ks)
        return copy

    @classmethod
    def release_problem(cls, problem_id):
        '''
        Move a problem to the 'Public' course and hand it to 'first_admin'.
        '''
        course = Course('Public').obj
        problem = Problem(problem_id).obj
        problem.courses = [course]
        problem.owner = 'first_admin'
        problem.save()

    def is_test_case_ready(self) -> bool:
        '''
        Whether a test case zip is stored in GridFS or MinIO.
        '''
        return (self.test_case.case_zip.grid_id is not None
                or self.test_case.case_zip_minio_path is not None)

    def _read_minio_object(self, path: str) -> bytes:
        '''
        Download the whole MinIO object at `path` from the bucket.
        '''
        minio_client = MinioClient()
        resp = minio_client.client.get_object(minio_client.bucket, path)
        try:
            return resp.read()
        finally:
            # give the HTTP connection back to the pool
            resp.close()
            resp.release_conn()

    def get_test_case(self) -> BinaryIO:
        '''
        Return this problem's test case zip.

        Reads it from MinIO when `case_zip_minio_path` is set, otherwise
        returns the legacy GridFS `case_zip` field.
        '''
        path = self.test_case.case_zip_minio_path
        if path is not None:
            return BytesIO(self._read_minio_object(path))
        # fallback to legacy GridFS storage
        return self.test_case.case_zip

    def migrate_gridfs_to_minio(self):
        '''
        migrate test case from gridfs to minio

        The GridFS copy is only removed when both copies have the same
        md5 checksum.
        '''
        if self.test_case.case_zip.grid_id is None:
            self.logger.info(
                f"no test case to migrate. problem={self.problem_id}")
            return

        if self.test_case.case_zip_minio_path is None:
            self.logger.info(
                f"uploading test case to minio. problem={self.problem_id}")
            self._save_test_case_zip(self.test_case.case_zip)
            self.logger.info(
                f"test case uploaded to minio. problem={self.problem_id} path={self.test_case.case_zip_minio_path}"
            )

        if self.check_test_case_consistency():
            self.logger.info(
                f"removing test case in gridfs. problem={self.problem_id}")
            self._remove_test_case_in_mongodb()
        else:
            self.logger.warning(
                f"data inconsistent after migration, keeping test case in gridfs. problem={self.problem_id}"
            )

    def _remove_test_case_in_mongodb(self):
        '''
        Delete the legacy GridFS test case zip of this problem.
        '''
        self.test_case.case_zip.delete()
        self.save()
        self.reload('test_case')

    def check_test_case_consistency(self) -> bool:
        '''
        Compare md5 checksums of the MinIO and GridFS test case copies.

        Returns False when the GridFS data cannot be read.
        '''
        minio_data = self._read_minio_object(
            self.test_case.case_zip_minio_path)

        gridfs_data = self.test_case.case_zip.read()
        if gridfs_data is None:
            self.logger.warning(
                f"gridfs test case is None but proxy is not updated. problem={self.problem_id}"
            )
            return False

        minio_checksum = md5(minio_data).hexdigest()
        gridfs_checksum = md5(gridfs_data).hexdigest()

        self.logger.info(
            f"calculated minio checksum. problem={self.problem_id} checksum={minio_checksum}"
        )
        self.logger.info(
            f"calculated gridfs checksum. problem={self.problem_id} checksum={gridfs_checksum}"
        )

        return minio_checksum == gridfs_checksum

    # TODO: hope minio SDK to provide more high-level API
    def generate_urls_for_uploading_test_case(
        self,
        length: int,
        part_size: int,
    ) -> UploadInfo:
        '''
        Start a multipart upload of a new test case zip and presign one
        PUT url (valid for 30 minutes) per part.

        Args:
            length: total size of the zip in bytes
            part_size: size of each part in bytes
        Raises:
            ValueError: if `length` or `part_size` is not positive
        '''
        # validate before any side effect: a bad part size used to
        # overwrite the problem's path and leave a dangling upload
        # behind before failing with ZeroDivisionError
        if length <= 0 or part_size <= 0:
            raise ValueError('length and part_size should be positive')
        # TODO: update url after uploading completed
        # TODO: handle failed uploading
        path = self._generate_test_case_obj_path()
        self.update(test_case__case_zip_minio_path=path)

        minio_client = MinioClient()
        upload_id = minio_client.client._create_multipart_upload(
            minio_client.bucket,
            path,
            headers={'Content-Type': 'application/zip'},
        )
        # ceil(length / part_size)
        part_count = (length + part_size - 1) // part_size

        def get(i: int):
            # MinIO part numbers are 1-based
            return minio_client.client.get_presigned_url(
                'PUT',
                minio_client.bucket,
                path,
                expires=timedelta(minutes=30),
                extra_query_params={
                    'partNumber': str(i + 1),
                    'uploadId': upload_id
                },
            )

        return UploadInfo(
            urls=[get(i) for i in range(part_count)],
            upload_id=upload_id,
        )

    def complete_test_case_upload(self, upload_id: str, parts: list):
        '''
        Complete the multipart upload started by
        `generate_urls_for_uploading_test_case` and validate the result.

        Raises:
            BadTestCase: if the uploaded zip is invalid; the problem's
                MinIO path is cleared and the uploaded object is removed
                (best-effort)
        '''
        path = self.test_case.case_zip_minio_path
        minio_client = MinioClient()
        minio_client.client._complete_multipart_upload(
            minio_client.bucket,
            path,
            upload_id,
            parts,
        )

        try:
            test_case = self.get_test_case()
            self._validate_test_case(test_case)
        except BadTestCase:
            self.update(test_case__case_zip_minio_path=None)
            self.reload('test_case')
            # the rejected archive is no longer referenced; drop it, but
            # never let a cleanup failure hide the validation error
            try:
                minio_client.client.remove_object(minio_client.bucket, path)
            except Exception:
                self.logger.warning(
                    f"failed to remove invalid test case. problem={self.problem_id} path={path}"
                )
            raise