Coverage for mongo/problem/problem.py: 79%

272 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-11 18:37 +0000

1import json 

2import enum 

3from hashlib import md5 

4from datetime import datetime, timedelta 

5from typing import ( 

6 Any, 

7 BinaryIO, 

8 Dict, 

9 List, 

10 Optional, 

11) 

12from dataclasses import dataclass 

13from io import BytesIO 

14from ulid import ULID 

15from .. import engine 

16from ..base import MongoBase 

17from ..course import * 

18from ..utils import (RedisCache, doc_required, drop_none, MinioClient) 

19from ..user import User 

20from .exception import BadTestCase 

21from .test_case import ( 

22 SimpleIO, 

23 ContextIO, 

24 IncludeDirectory, 

25 TestCaseRule, 

26) 

27 

# Only the `Problem` wrapper is part of this module's public API.
__all__ = ("Problem",)

29 

30 

@dataclass
class UploadInfo:
    """
    What a client needs to upload a test case zip in several parts.

    Attributes:
        urls: one presigned PUT URL per part, in part-number order
            (the first URL uploads part 1).
        upload_id: identifier of the multipart upload; it has to be sent
            back when the upload is completed.
    """
    urls: List[str]  # presigned part URLs, ordered by part number
    upload_id: str  # multipart upload identifier

35 

36 

37class Problem(MongoBase, engine=engine.Problem): 

38 

class Permission(enum.IntFlag):
    """
    Capabilities a user may hold on a problem; combine them with `|`.

    The values are the same powers of two `enum.auto()` assigns to an
    `IntFlag` declared in this order.
    """
    # the user may view the problem
    VIEW = 1
    # the problem is online (listed) for the user
    ONLINE = 2
    # the user may manage (edit) the problem
    MANAGE = 4

43 

def detailed_info(self, *ks, **kns) -> Dict[str, Any]:
    '''
    Return selected fields of this problem as a (possibly nested) dict.

    The test case archive reference (`caseZip`) and its MinIO object path
    (`caseZipMinioPath`) are never included; the test case data itself has
    to be fetched through another route. `courses` is reported as a list
    of course names.

    Args:
        ks (*str): field names to fetch; each value is stored under the
            same name.
        kns (**str): output key -> field name. Both the output key and the
            field name may use `__` to address nested dicts, e.g.
            `lang='testCase__language'` stores `testCase.language` under
            `lang`, and `info__name='problemName'` stores it under
            `ret['info']['name']`.

    Returns:
        A dict with the requested data, or an empty dict if the problem
        does not exist.

    Raises:
        KeyError: if a requested field path does not exist.
    '''
    if not self:
        return {}
    # problem -> dict keyed by the mongo field names (e.g. `testCase`)
    _ret = self.to_mongo()
    # Strip fields that must not be exposed: the GridFS zip can not be
    # serialized and the MinIO path is internal. A document without a
    # `testCase` entry is tolerated instead of raising KeyError.
    test_case = _ret.get('testCase')
    if test_case is not None:
        test_case.pop('caseZip', None)
        test_case.pop('caseZipMinioPath', None)
    # convert course documents to course names
    _ret['courses'] = [course.course_name for course in self.courses]
    # positional field names are stored under their own name
    fields = {**kns, **{k: k for k in ks}}
    ret = {}
    for key, path in fields.items():
        # walk the source path, e.g. 'testCase__language'
        src_parts = path.split('__')
        value = _ret[src_parts[0]]
        for part in src_parts[1:]:
            value = value[part]
        # create nested output dicts along the destination key
        *parents, leaf = key.split('__')
        target = ret
        for part in parents:
            target = target.setdefault(part, {})
        target[leaf] = value
    return ret

90 

def allowed(self, language):
    '''
    Tell whether submissions in `language` are accepted for this problem.

    `language` is a bit index into the `allowed_language` mask; only the
    indices 0, 1 and 2 are recognised. Problems whose `problem_type` is 2
    accept every language without consulting the mask.
    '''
    if self.problem_type == 2:
        return True
    in_range = 0 <= language < 3
    return in_range and ((self.allowed_language >> language) & 1) == 1

97 

def submit_count(self, user: User) -> int:
    '''
    Return how many times `user` has submitted to this problem today.

    The counters live in `user.problem_submission`, keyed by the
    stringified problem id. If the user's last submission was not made on
    the current local date, the whole counter map is cleared and 0 is
    returned.
    '''
    last_day = user.last_submit.date()
    if last_day == datetime.now().date():
        return user.problem_submission.get(str(self.problem_id), 0)
    # first check on a new day: reset the quota counters
    user.update(problem_submission={})
    return 0

107 

def running_homeworks(self) -> List:
    '''
    Return `Homework` wrappers for the entries of `self.homeworks` whose
    duration contains the current time.
    '''
    # local import, presumably to avoid an import cycle
    from ..homework import Homework
    current = datetime.now()
    ongoing = (hw for hw in self.homeworks if current in hw.duration)
    return [Homework(hw.id) for hw in ongoing]

112 

def is_valid_ip(self, ip: str):
    '''
    Return True unless some currently running homework rejects `ip`.

    With no running homework every address is accepted.
    '''
    for homework in self.running_homeworks():
        if not homework.is_valid_ip(ip):
            return False
    return True

115 

def get_submission_status(self) -> Dict[str, int]:
    '''
    Count this problem's submissions grouped by their status.

    Returns:
        A dict mapping each status value to its number of submissions.
    '''
    group_by_status = {
        "$group": {
            "_id": "$status",
            "count": {"$sum": 1},
        },
    }
    submissions = engine.Submission.objects(problem=self.id)
    return {
        row['_id']: row['count']
        for row in submissions.aggregate([group_by_status])
    }

128 

def get_ac_user_count(self) -> int:
    '''
    Number of distinct users with at least one submission of status 0
    (counted as accepted here) on this problem.
    '''
    accepted = engine.Submission.objects(problem=self.id, status=0)
    distinct_users = accepted.distinct('user')
    return len(distinct_users)

135 

def get_tried_user_count(self) -> int:
    '''Number of distinct users who submitted to this problem at least once.'''
    submissions = engine.Submission.objects(problem=self.id)
    return len(submissions.distinct('user'))

140 

@doc_required('user', User)
def high_score_key(self, user: User) -> str:
    '''Return the cache key that stores `user`'s best score on this problem.'''
    problem_key, user_key = self.id, user.id
    return f'PROBLEM_{problem_key}_{user_key}_HIGH_SCORE'

144 

@doc_required('user', User)
def get_high_score(self, user: User) -> int:
    '''
    Get the highest score `user` has achieved on this problem.

    The result is cached in Redis under `high_score_key(user)` for 600
    seconds. It is 0 when the user has no submission to this problem, and
    negative scores are reported as 0.
    '''
    cache = RedisCache()
    key = self.high_score_key(user=user)
    if (val := cache.get(key)) is not None:
        return int(val.decode())
    # TODO: avoid calling mongoengine API directly
    # `first()` issues a single query and yields None when nothing matches,
    # instead of a separate `count()` query followed by `[0]`.
    best = engine.Submission.objects(
        user=user.id,
        problem=self.id,
    ).only('score').order_by('-score').first()
    if best is None:
        high_score = 0
    else:
        # It might < 0 if there is only incomplete submission
        high_score = max(best.score, 0)
    cache.set(key, high_score, ex=600)
    return high_score

166 

@doc_required('user', User)
def own_permission(self, user: User) -> Permission:
    """
    Compute the set of `Permission` flags `user` holds on this problem.

    - VIEW is inherited from any of the problem's courses on which the
      user has `Course.Permission.VIEW`.
    - ONLINE is granted, per course, when `problem_status` is 0 and either
      the problem is not part of any homework of that course, or a homework
      of that course containing it has already started.
    - Admins (`role == 0`) and the problem's owner get VIEW, ONLINE and
      MANAGE.

    NOTE(review): the ONLINE grant does not depend on the user having VIEW
    on that course — confirm this is intended.
    """

    user_cap = self.Permission(0)
    for course in map(Course, self.courses):
        # inherit course permission
        if course.permission(user, Course.Permission.VIEW):
            user_cap |= self.Permission.VIEW

        # only problems with status 0 can become online
        if self.problem_status == 0:
            check_public_problem = True
            for homework in course.homeworks:
                if self.problem_id in homework.problem_ids:
                    check_public_problem = False
                    # online once a homework containing it has started
                    if datetime.now() >= homework.duration.start:
                        user_cap |= self.Permission.ONLINE

            # problem does not belong to any homework of this course
            if check_public_problem:
                user_cap |= self.Permission.ONLINE

    # admins (role 0) and the problem owner get every permission
    if user.role == 0 or self.owner == user.username:
        user_cap |= self.Permission.VIEW
        user_cap |= self.Permission.ONLINE
        user_cap |= self.Permission.MANAGE

    return user_cap

200 

def permission(self, user: User, req: Permission) -> bool:
    """
    Check whether `user` holds every flag in `req` on this problem.

    `req` may combine several flags (e.g. VIEW | MANAGE); all of them must
    be granted for the check to pass.
    """
    granted = self.own_permission(user=user)
    covered = granted & req
    return covered == req

207 

@classmethod
def get_problem_list(
    cls,
    user,
    offset: int = 0,
    count: int = -1,
    problem_id: int = None,
    name: str = None,
    tags: list = None,
    course: str = None,
):
    '''
    Return the problems that are ONLINE for `user`, ordered by problem id.

    Args:
        user: the user whose ONLINE permission filters the result.
        offset: index of the first problem to return.
        count: maximum number of problems to return; negative means all.
        problem_id, name: optional exact-match filters.
        tags: optional filter; problems having any of these tags.
        course: optional course name filter.

    Returns:
        A list of engine problem documents; empty if `course` does not
        exist.

    Raises:
        IndexError: if `offset` is negative, or is past the end of a
            non-empty result.
    '''
    if course is not None:
        wrapper = Course(course)
        if not wrapper:
            return []
        course = wrapper.obj
    # query filters; unset ones are removed by drop_none
    query = drop_none({
        'problem_id': problem_id,
        'problem_name': name,
        'courses': course,
        'tags__in': tags,
    })
    candidates = engine.Problem.objects(**query).order_by('problemId')
    online = cls.Permission.ONLINE
    visible = [
        doc for doc in candidates
        if cls(doc).permission(user=user, req=online)
    ]
    total = len(visible)
    if offset < 0 or (total and offset >= total):
        raise IndexError
    stop = total if count < 0 else min(total, offset + count)
    return visible[offset:stop]

244 

@classmethod
def add(
    cls,
    user: User,
    courses: List[str],
    problem_name: str,
    status: Optional[int] = None,
    description: Optional[Dict[str, Any]] = None,
    tags: Optional[List[str]] = None,
    type: Optional[int] = None,
    test_case_info: Optional[Dict[str, Any]] = None,
    can_view_stdout: bool = False,
    allowed_language: Optional[int] = None,
    quota: Optional[int] = None,
    default_code: Optional[str] = None,
):
    '''
    Create a problem owned by `user` and return the new `problem_id`.

    Field values go through `drop_none` before being written, so
    arguments left as None are presumably omitted and the engine defaults
    apply. The programming-related fields (`test_case_info`,
    `can_view_stdout`, `allowed_language`) are written in a second update,
    which is skipped entirely when `type` is 2.

    Raises:
        ValueError: if `courses` is empty.
        engine.DoesNotExist: if any of the course names does not exist.
    '''
    if len(courses) == 0:
        raise ValueError('No course provided')
    # resolve every course name before creating the problem
    course_ids = []
    for course_name in courses:
        course = Course(course_name)
        if not course:
            raise engine.DoesNotExist
        course_ids.append(course.id)

    base_fields = drop_none({
        'courses': course_ids,
        'problem_status': status,
        'problem_type': type,
        'problem_name': problem_name,
        'description': description,
        'owner': user.username,
        'tags': tags,
        'quota': quota,
        'default_code': default_code,
    })
    created = cls.engine(**base_fields).save()

    programming_fields = drop_none({
        'test_case': test_case_info,
        'can_view_stdout': can_view_stdout,
        'allowed_language': allowed_language,
    })
    if type != 2 and programming_fields:
        created.update(**programming_fields)
    return created.problem_id

291 

@classmethod
def edit_problem(
    cls,
    user: User,
    problem_id: int,
    courses: List[str],
    status: int,
    problem_name: str,
    description: Dict[str, Any],
    tags: List[str],
    type,
    test_case_info: Optional[Dict[str, Any]] = None,
    allowed_language: int = 7,
    can_view_stdout: bool = False,
    quota: int = -1,
    default_code: str = '',
):
    '''
    Overwrite the settings of an existing problem.

    For problems whose `type` is not 2, the allowed languages, stdout
    visibility and test case are updated too. When `test_case_info` is
    omitted (None or empty), the stored test case is kept unchanged.
    Otherwise it replaces the stored test case metadata, and the existing
    archive references (GridFS `case_zip` and MinIO path) are carried over.

    NOTE(review): `owner` is set to the editing user, so editing transfers
    ownership — confirm this is intended.

    Raises:
        ValueError: if `test_case_info` is given for a type other than 2
            and its task scores do not sum to 100.
        engine.DoesNotExist: if the problem or any of the courses does not
            exist.
    '''
    # Validate before writing anything. A missing `test_case_info` means
    # "keep the stored test case", so there is nothing to validate; indexing
    # into None here used to crash with a TypeError.
    if type != 2 and test_case_info:
        score = sum(t['taskScore'] for t in test_case_info['tasks'])
        if score != 100:
            raise ValueError("Cases' scores should be 100 in total")
    wrapper = Problem(problem_id)
    if not wrapper:
        raise engine.DoesNotExist
    problem = wrapper.obj
    course_objs = []
    for name in courses:
        if not (course := Course(name)):
            raise engine.DoesNotExist
        course_objs.append(course.obj)
    problem.update(
        courses=course_objs,
        problem_status=status,
        problem_type=type,
        problem_name=problem_name,
        description=description,
        owner=user.username,
        tags=tags,
        quota=quota,
        default_code=default_code,
    )
    if type != 2:
        # preprocess test case: keep the stored one unless new info is given
        test_case = problem.test_case
        if test_case_info:
            test_case = engine.ProblemTestCase.from_json(
                json.dumps(test_case_info))
            # the archive itself is not part of `test_case_info`; keep the
            # existing storage references
            test_case.case_zip = problem.test_case.case_zip
            test_case.case_zip_minio_path = problem.test_case.case_zip_minio_path
        problem.update(
            allowed_language=allowed_language,
            can_view_stdout=can_view_stdout,
            test_case=test_case,
        )

343 

def update_test_case(self, test_case: BinaryIO):
    '''
    Replace this problem's test case with an uploaded zip archive.

    The archive is validated first. If validation passes, the stream is
    rewound and uploaded to MinIO, and the problem's
    `test_case.case_zip_minio_path` is pointed at the new object.

    Args:
        test_case: seekable binary stream containing the test case zip.

    Raises:
        BadTestCase: if the archive violates a directory rule or does not
            match exactly one supported test case format.
        Errors raised by the validation rules (e.g. for an unreadable
        archive) or by the upload propagate unchanged.
    '''
    self._validate_test_case(test_case)
    # validation may have moved the stream position; rewind before upload
    test_case.seek(0)
    self._save_test_case_zip(test_case)

358 

def _save_test_case_zip(self, test_case: BinaryIO):
    '''
    Upload `test_case` to MinIO under a freshly generated object path and
    record that path in `test_case.case_zip_minio_path`.

    The stream length is passed as -1 (unknown), so the MinIO SDK uploads
    it in 5 MiB parts.
    '''
    minio = MinioClient()
    object_path = self._generate_test_case_obj_path()
    five_mib = 5 * 1024 * 1024
    minio.client.put_object(
        minio.bucket,
        object_path,
        test_case,
        -1,
        part_size=five_mib,
        content_type='application/zip',
    )
    self.update(test_case__case_zip_minio_path=object_path)
    # refresh the in-memory test case from the database
    self.reload('test_case')

375 

def _generate_test_case_obj_path(self):
    '''Return a new ULID-based MinIO object key for a test case zip.'''
    object_id = ULID()
    return f'problem-test-case/{object_id}.zip'

378 

def _validate_test_case(self, test_case: BinaryIO):
    '''
    Check an uploaded test case archive against the supported layouts.

    The directory rules (`include`, `share` and the legacy `chaos`) are
    checked first; any error they raise propagates directly. After that,
    the archive must satisfy exactly one of the SimpleIO and ContextIO
    formats.

    Raises:
        BadTestCase: if the archive matches both formats ('ambiguous test
            case format') or neither of them (both rule errors are included
            in the message).
    '''
    directory_rules: List[TestCaseRule] = [
        IncludeDirectory(self, dirname)
        # 'chaos' is kept for backward compatibility
        for dirname in ('include', 'share', 'chaos')
    ]
    for rule in directory_rules:
        rule.validate(test_case)

    # exactly one of these formats must accept the archive
    format_rules = [
        SimpleIO(self, ['include/', 'share/', 'chaos/']),
        ContextIO(self),
    ]
    errors = []
    for rule in format_rules:
        try:
            rule.validate(test_case)
        except BadTestCase as err:
            errors.append(err)

    if not errors:
        raise BadTestCase('ambiguous test case format')
    if len(errors) == 2:
        first, second = errors
        raise BadTestCase(
            f'invalid test case format\n\n{first}\n\n{second}')

409 

@classmethod
def copy_problem(cls, user, problem_id):
    '''
    Save a new problem owned by `user` that copies the status, type, name,
    description, tags and test case of problem `problem_id`.

    NOTE(review): the copy gets no courses, and allowed_language,
    can_view_stdout, quota and default_code are not copied (`copy_to`
    copies them). The test case embedded document is assigned as-is, so its
    storage references appear to be shared with the source problem.
    Nothing is returned.
    '''
    source = Problem(problem_id).obj
    fields = dict(
        problem_status=source.problem_status,
        problem_type=source.problem_type,
        problem_name=source.problem_name,
        description=source.description,
        owner=user.username,
        tags=source.tags,
        test_case=source.test_case,
    )
    engine.Problem(**fields).save()

422 

@doc_required('target', Course, src_none_allowed=True)
def copy_to(
    self,
    user: User,
    target: Optional[Course] = None,
    **override,
) -> int:
    '''
    Copy this problem to a target course; the copy is hidden by default.

    Args:
        user (User): The user who executes this action and will become
            the owner of the copied problem.
        target (Optional[Course] = None): The course this problem will be
            copied to; defaults to the first of the original courses.
        override: Field values passed to `Problem.add` that replace the
            copied ones, e.g. `status` to make the copy visible.

    Returns:
        The `problem_id` of the newly created problem, as returned by
        `Problem.add`.
    '''
    target = self.courses[0] if target is None else target
    # Copied problem is hidden by default
    status = override.pop('status', Problem.engine.Visibility.HIDDEN)
    ks = dict(
        user=user,
        courses=[target.course_name],
        problem_name=self.problem_name,
        status=status,
        description=self.description.to_mongo(),
        tags=self.tags,
        type=self.problem_type,
        test_case_info=self.test_case.to_mongo(),
        can_view_stdout=self.can_view_stdout,
        allowed_language=self.allowed_language,
        quota=self.quota,
        default_code=self.default_code,
    )
    ks.update(override)
    return self.add(**ks)

460 

@classmethod
def release_problem(cls, problem_id):
    '''
    Move problem `problem_id` into the 'Public' course only and set its
    owner username to 'first_admin'.
    '''
    public_course = Course('Public').obj
    doc = Problem(problem_id).obj
    doc.courses = [public_course]
    doc.owner = 'first_admin'
    doc.save()

468 

def is_test_case_ready(self) -> bool:
    '''
    Return True if a test case archive is stored either in GridFS (legacy)
    or in MinIO.
    '''
    tc = self.test_case
    return not (tc.case_zip.grid_id is None
                and tc.case_zip_minio_path is None)

472 

def get_test_case(self) -> BinaryIO:
    '''
    Return this problem's test case zip as a readable binary stream.

    If a MinIO object path is recorded, the whole object is downloaded into
    an in-memory `BytesIO`. Otherwise the legacy GridFS `case_zip` field is
    returned as-is.
    '''
    path = self.test_case.case_zip_minio_path
    if path is not None:
        minio_client = MinioClient()
        # If get_object itself raises there is no response to release, so
        # acquire it before entering try/finally (no `locals()` check).
        resp = minio_client.client.get_object(minio_client.bucket, path)
        try:
            return BytesIO(resp.read())
        finally:
            resp.close()
            resp.release_conn()

    # fallback to legacy GridFS storage
    return self.test_case.case_zip

489 

def migrate_gridfs_to_minio(self):
    '''
    Move this problem's test case from GridFS to MinIO.

    1. Do nothing if there is no GridFS test case.
    2. Upload the GridFS zip to MinIO unless a MinIO path is already set.
    3. Delete the GridFS copy only if both copies have the same checksum;
       otherwise keep it and log a warning.
    '''
    log = self.logger
    pid = self.problem_id

    if self.test_case.case_zip.grid_id is None:
        log.info(f"no test case to migrate. problem={pid}")
        return

    already_uploaded = self.test_case.case_zip_minio_path is not None
    if not already_uploaded:
        log.info(f"uploading test case to minio. problem={pid}")
        self._save_test_case_zip(self.test_case.case_zip)
        # read the path after the upload reloaded `test_case`
        log.info(
            "test case uploaded to minio. "
            f"problem={pid} path={self.test_case.case_zip_minio_path}")

    if not self.check_test_case_consistency():
        log.warning(
            "data inconsistent after migration, keeping test case in gridfs. "
            f"problem={pid}")
        return

    log.info(f"removing test case in gridfs. problem={pid}")
    self._remove_test_case_in_mongodb()

515 

516 def _remove_test_case_in_mongodb(self): 

517 self.test_case.case_zip.delete() 

518 self.save() 

519 self.reload('test_case') 

520 

def check_test_case_consistency(self):
    '''
    Compare the MinIO copy of this problem's test case with the GridFS copy.

    Both copies are read fully into memory; their MD5 digests are logged
    and compared.

    Returns:
        True if both copies have the same MD5 digest. False if they differ,
        or if reading the GridFS copy returns None.
    '''
    minio_client = MinioClient()
    # Acquire the response before try/finally: if get_object raises there is
    # nothing to release (replaces the fragile `'resp' in locals()` check).
    resp = minio_client.client.get_object(
        minio_client.bucket,
        self.test_case.case_zip_minio_path,
    )
    try:
        minio_data = resp.read()
    finally:
        # return the HTTP connection to the pool even if reading fails
        resp.close()
        resp.release_conn()

    gridfs_data = self.test_case.case_zip.read()
    if gridfs_data is None:
        self.logger.warning(
            f"gridfs test case is None but proxy is not updated. problem={self.problem_id}"
        )
        return False

    minio_checksum = md5(minio_data).hexdigest()
    gridfs_checksum = md5(gridfs_data).hexdigest()

    self.logger.info(
        f"calculated minio checksum. problem={self.problem_id} checksum={minio_checksum}"
    )
    self.logger.info(
        f"calculated gridfs checksum. problem={self.problem_id} checksum={gridfs_checksum}"
    )

    return minio_checksum == gridfs_checksum

552 

553 # TODO: hope minio SDK to provide more high-level API 

def generate_urls_for_uploading_test_case(
    self,
    length: int,
    part_size: int,
) -> UploadInfo:
    '''
    Start a multipart upload of a new test case zip and presign one PUT URL
    per part.

    Args:
        length: total size in bytes of the zip the client will upload.
        part_size: size in bytes of each part (the last one may be smaller).

    Returns:
        UploadInfo holding ceil(length / part_size) URLs, each valid for 30
        minutes, plus the multipart upload id. The id must be passed to
        `complete_test_case_upload`.

    Raises:
        ValueError: if `length` or `part_size` is not positive. This is
            checked before any side effect, so invalid input cannot leave
            the problem pointing at an object that will never exist (a zero
            `part_size` used to fail with ZeroDivisionError only after the
            path was stored).
    '''
    if length <= 0:
        raise ValueError(f'length should be positive, got {length}')
    if part_size <= 0:
        raise ValueError(f'part_size should be positive, got {part_size}')
    # TODO: update url after uploading completed
    # TODO: handle failed uploading
    path = self._generate_test_case_obj_path()

    minio_client = MinioClient()
    # `_create_multipart_upload` is a private minio SDK method
    upload_id = minio_client.client._create_multipart_upload(
        minio_client.bucket,
        path,
        headers={'Content-Type': 'application/zip'},
    )
    # only point the problem at the new object once the upload exists
    self.update(test_case__case_zip_minio_path=path)
    part_count = (length + part_size - 1) // part_size  # ceil division

    def presign_part(part_number: int) -> str:
        # part numbers of a multipart upload start at 1
        return minio_client.client.get_presigned_url(
            'PUT',
            minio_client.bucket,
            path,
            expires=timedelta(minutes=30),
            extra_query_params={
                'partNumber': str(part_number),
                'uploadId': upload_id,
            },
        )

    return UploadInfo(
        urls=[presign_part(n) for n in range(1, part_count + 1)],
        upload_id=upload_id,
    )

588 

def complete_test_case_upload(self, upload_id: str, parts: list):
    '''
    Finish the multipart upload of this problem's test case and validate
    the assembled zip.

    Args:
        upload_id: the id returned as `UploadInfo.upload_id` by
            `generate_urls_for_uploading_test_case`.
        parts: the uploaded parts, passed straight to the private minio SDK
            method `_complete_multipart_upload`. NOTE(review): the expected
            element type is defined by that SDK method — confirm.

    Raises:
        BadTestCase: if the uploaded archive fails validation; the
            problem's `case_zip_minio_path` is cleared before re-raising.
    '''
    minio_client = MinioClient()
    # `_complete_multipart_upload` is a private minio SDK method
    minio_client.client._complete_multipart_upload(
        minio_client.bucket,
        self.test_case.case_zip_minio_path,
        upload_id,
        parts,
    )

    try:
        test_case = self.get_test_case()
        self._validate_test_case(test_case)
    except BadTestCase:
        # NOTE(review): the uploaded object is not deleted from MinIO here,
        # and the previous path (overwritten when the upload URLs were
        # generated) is not restored, so the problem is left with no MinIO
        # test case — confirm this is intended.
        self.update(test_case__case_zip_minio_path=None)
        raise